In [17]:
import pandas as pd
import numpy as np
import torch
from torch import nn 
from torch.utils.data import DataLoader     # 데이터로더는 데이터셋을 iterable하게 감싸는 역할
from torchvision import datasets            # 데이터셋은 샘플과 정답을 저장함
from torchvision.transforms import ToTensor
import clip
from PIL import Image

## CLIP model

In [18]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [25]:
image_title = 'A horse in the space.png'

prompts = [
    "A horse in the space",
    "A dog in the space",
    "A bear in the space",
    "A person in the space",
    "A horse in the park",
    "A dog in the park",
    "A bear in the park",
    "A person in the park",
]

image = preprocess(Image.open(image_title)).unsqueeze(0).to(device)
text = clip.tokenize(prompts).to(device)

In [30]:
model(image, text)

(tensor([[30.6719, 24.0625, 22.5156, 23.5312, 23.3750, 16.5312, 14.7188, 17.2344]],
        device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>),
 tensor([[30.6719],
         [24.0625],
         [22.5156],
         [23.5312],
         [23.3750],
         [16.5312],
         [14.7188],
         [17.2344]], device='cuda:0', dtype=torch.float16, grad_fn=<TBackward0>))

## Our Model

In [27]:
def get_image_features_per_layer(clip_model, image):
    features = [torch.zeros([1, 768]) for i in range(clip_model.visual.transformer.layers)]

    # image -> tokens
    x = clip_model.visual.conv1(image.type(clip_model.visual.conv1.weight.dtype))
    x = x.reshape(x.shape[0], x.shape[1], -1)   # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)                      # shape = [*, grid ** 2, width]
    x = torch.cat([clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + clip_model.visual.positional_embedding.to(x.dtype)
    x = clip_model.visual.ln_pre(x)

    # tokens -> transformer -> feature_embeddings
    x = x.permute(1, 0, 2)          # NLD -> LND

    for i in range(clip_model.visual.transformer.layers):
        x = clip_model.visual.transformer.resblocks[i](x)
        tmp = x.permute(1, 0, 2)    # LND -> NLD
        tmp = tmp[:, 0, :]
        features[i].copy_(tmp)

    tensor = torch.stack(features, dim=0).detach()
    tensor.to(device)
    # tensor.requires_grad = False

    return tensor

In [28]:
def get_text_features_per_layer(clip_model, text):
    features = [torch.zeros([*text.shape, 512]) for i in range(clip_model.transformer.layers)]

    #%% clip_model.encode_image

    x = clip_model.token_embedding(text).type(clip_model.dtype)  # [batch_size, n_ctx, d_clip_model]

    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)          # NLD -> LND

    for i in range(clip_model.transformer.layers):
        x = clip_model.transformer.resblocks[i](x)
        tmp = x.permute(1, 0, 2)    # LND -> NLD
        features[i].copy_(tmp)

    tensor = torch.stack(features, dim=0).detach()
    tensor.to(device)
    # tensor.requires_grad = False

    return tensor

#### lab

In [32]:
model.visual.ln_post(get_image_features_per_layer(model, image).to(device)).shape

torch.Size([12, 1, 768])

In [33]:
model.ln_final

LayerNorm((512,), eps=1e-05, elementwise_affine=True)

In [34]:
model.ln_final(get_text_features_per_layer(model, text).to(device)).shape

torch.Size([12, 8, 77, 512])

In [35]:
model.visual.proj.shape

torch.Size([768, 512])

In [36]:
torch.rand(model.visual.proj.shape)

tensor([[0.7060, 0.3264, 0.8114,  ..., 0.3201, 0.6791, 0.0149],
        [0.8198, 0.7222, 0.8461,  ..., 0.4048, 0.2797, 0.4489],
        [0.4857, 0.4051, 0.6350,  ..., 0.2085, 0.0135, 0.5073],
        ...,
        [0.2146, 0.3847, 0.2523,  ..., 0.9563, 0.4808, 0.8936],
        [0.9526, 0.3367, 0.8129,  ..., 0.4689, 0.7941, 0.3421],
        [0.1183, 0.8862, 0.1054,  ..., 0.7301, 0.8447, 0.9090]])

In [37]:
proj = torch.stack([model.visual.proj for _ in range(12)])
proj.shape

torch.Size([12, 768, 512])

In [38]:
tmp = [model.visual.proj for _ in range(12)]
tmp[0] is tmp[1], proj[0] is proj[1]

(True, False)

In [39]:
print(proj[0][0][0].item(), proj[1][0][0].item())
proj[0][0][0] = -1
print(proj[0][0][0].item(), proj[1][0][0].item())
# light copy

-0.0026264190673828125 -0.0026264190673828125
-1.0 -0.0026264190673828125


In [41]:
(get_image_features_per_layer(model, image).to(device) @ proj).shape

RuntimeError: expected scalar type Half but found Float

In [42]:
tmp = get_image_features_per_layer(model, image).to(device)
tmp1 = model.visual.ln_post(tmp)[0]
tmp2 = model.visual.ln_post(tmp[0])
(tmp1 == tmp2).sum()

tensor(768, device='cuda:0')

#### inference

In [43]:
# Define model
class OurCLIP(nn.Module):        
    def __init__(self, clip_model, # use pre-trained clip model
                 use_one_ln1=True, use_one_ln2=True, 
                 use_one_projection1=True, use_one_projection2=True,
                 trainable_ln1=False, trainable_ln2=False, 
                 trainable_projection1=False, trainable_projection2=False,
                 threshold = 20): 
        super().__init__()

        self.dtype = clip_model.dtype
        self.threshold = threshold

        ####################### 미구현
        if use_one_ln1:
            self.ln_post = clip_model.visual.ln_post
        else: pass
        if use_one_ln2:
            self.ln_final = clip_model.ln_final
        else: pass
        ########################

        if use_one_projection1:
            self.visual_projection = torch.stack([clip_model.visual.proj for _ in range(12)]).type(torch.float)
            if not trainable_projection1:
                self.visual_projection.detach().type(torch.float)
        else:
            self.visual_projection = torch.stack([*[torch.rand(clip_model.visual.proj.shape) for _ in range(11)],
                                                   clip_model.visual.proj]).type(torch.float)
        
        if use_one_projection2:
            self.textual_projection = torch.stack([clip_model.text_projection for _ in range(12)])
            if not trainable_projection2:
                self.visual_projection.detach()
        else:
            self.textual_projection = torch.stack([*[torch.rand(clip_model.text_projection.shape) for _ in range(11)], 
                                                   clip_model.text_projection])


    def forward(self, image_features, text_features):    
        image_features = image_features.to(device)
        text_features = text_features.to(device)

        image_features = self.ln_post(image_features)                   # (12, 1, 768)
        image_embeddings = image_features @ self.visual_projection      # (12, 1, 512)  <- we'll use it

        text_features = self.ln_final(text_features).type(self.dtype)   # (12, 8, 77, 512)
        text_embeddings = text_features[:, torch.arange(text_features.shape[1]), text.argmax(dim=-1)] @ self.textual_projection  
                                                                        # (12, 8, 512)  <- we'll use it

        scores = [[], [], []]

        for _class in range(text_features.shape[1]):
            # print(image_embeddings[0].shape)
            # print(text_embeddings[0][[_class]].T.shape)
            # print((image_embeddings[0] @ text_embeddings[0][[_class]].T.type(torch.float)).item())
            img_similarity_by_layer = [[(image_embeddings[i] @ text_embeddings[j][[_class]].T.type(torch.float)).item()
                                         for j in range(12)] for i in range(12)]
            scores[0].append(np.max(img_similarity_by_layer))
            scores[1].append(np.mean(img_similarity_by_layer))
            scores[2].append(sum(list(map(lambda x: sum([y > self.threshold for y in x]), img_similarity_by_layer))))

        return scores

In [44]:
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

model = OurCLIP(clip_model).to(device)  # model을 initialize하는 부분
print(model)

OurCLIP(
  (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)


In [45]:
image_title = 'A horse in the space.png'

prompts = [
    "A horse in the space",
    "A dog in the space",
    "A bear in the space",
    "A person in the space",
    "A horse in the park",
    "A dog in the park",
    "A bear in the park",
    "A person in the park",
]

image = preprocess(Image.open(image_title)).unsqueeze(0).to(device)
text = clip.tokenize(prompts).to(device)

In [46]:
image_features = get_image_features_per_layer(clip_model, image)
text_features = get_text_features_per_layer(clip_model, text)

In [47]:
model(image_features, text_features)

[[33.82638168334961,
  26.978994369506836,
  26.948450088500977,
  30.468639373779297,
  25.630268096923828,
  20.795486450195312,
  22.559879302978516,
  23.461559295654297],
 [12.221618531478775,
  12.280218217107985,
  12.698671211798986,
  14.021323535177443,
  11.308028957910008,
  10.396428364846441,
  11.075213006801075,
  10.72875845928987],
 [14, 13, 13, 13, 6, 1, 6, 11]]

In [48]:
clip_model.text_projection.shape

torch.Size([512, 512])

## Train

### Load-Dataset

In [49]:
import deeplake
from PIL import Image
import numpy as np
import os, time
import torch
from torchvision import transforms, models

In [50]:
# Connect to the training and testing datasets
ds_train = deeplake.load('hub://activeloop/pacs-train')
ds_test = deeplake.load('hub://activeloop/pacs-test')

tform = transforms.Compose([
    transforms.RandomRotation(20), # Image augmentation
    transforms.ToTensor(), # Must convert to pytorch tensor for subsequent operations to run
    transforms.Normalize([0.5], [0.5]),
])
batch_size = 4

# Since torchvision transforms expect PIL images, we use the 'pil' decode_method for the 'images' tensor. This is much faster than running ToPILImage inside the transform
train_loader = ds_train.pytorch(num_workers = 0, shuffle = True, 
                                transform = {'images': tform, 'labels': None}, 
                                batch_size = batch_size, decode_method = {'images': 'pil'})
test_loader = ds_test.pytorch(num_workers = 0, transform = {'images': tform, 'labels': None}, 
                                batch_size = batch_size, decode_method = {'images': 'pil'})

hub://activeloop/pacs-train loaded successfully.
This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/pacs-train
hub://activeloop/pacs-test loaded successfully.
This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/pacs-test


In [56]:
# Specity the loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.1)

In [59]:
pacs_class = [
    'a dog', 'an elephant', 'a giraffe', 'a guitar', 'a horse', 'a house', 'a person'
]

def prompt(idx):
    return f"An image of {pacs_class[idx]}"

prompts = clip.tokenize([prompt(x) for x in range(7)]).to(device)
prompts.shape

torch.Size([7, 77])