In [1]:
import pandas as pd
import numpy as np
import torch
from torch import nn 
from torch.utils.data import DataLoader     # 데이터로더는 데이터셋을 iterable하게 감싸는 역할
from torchvision import datasets            # 데이터셋은 샘플과 정답을 저장함
from torchvision.transforms import ToTensor
import clip
from PIL import Image

## CLIP model

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
image_title = 'A horse in the space.png'

prompts = [
    "A horse in the space",
    "A dog in the space",
    "A bear in the space",
    "A person in the space",
    "A horse in the park",
    "A dog in the park",
    "A bear in the park",
    "A person in the park",
]

image = preprocess(Image.open(image_title)).unsqueeze(0).to(device)
text = clip.tokenize(prompts).to(device)

In [30]:
model(image, text)

(tensor([[30.6719, 24.0625, 22.5156, 23.5312, 23.3750, 16.5312, 14.7188, 17.2344]],
        device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>),
 tensor([[30.6719],
         [24.0625],
         [22.5156],
         [23.5312],
         [23.3750],
         [16.5312],
         [14.7188],
         [17.2344]], device='cuda:0', dtype=torch.float16, grad_fn=<TBackward0>))

## Our Model

In [60]:
def get_image_features_per_layer(clip_model, image):
    features = [torch.zeros([image.shape[0], 768]) for i in range(clip_model.visual.transformer.layers)]

    # image -> tokens
    x = clip_model.visual.conv1(image.type(clip_model.visual.conv1.weight.dtype))
    x = x.reshape(x.shape[0], x.shape[1], -1)   # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)                      # shape = [*, grid ** 2, width]
    x = torch.cat([clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + clip_model.visual.positional_embedding.to(x.dtype)
    x = clip_model.visual.ln_pre(x)

    # tokens -> transformer -> feature_embeddings
    x = x.permute(1, 0, 2)          # NLD -> LND

    for i in range(clip_model.visual.transformer.layers):
        x = clip_model.visual.transformer.resblocks[i](x)
        tmp = x.permute(1, 0, 2)    # LND -> NLD
        tmp = tmp[:, 0, :]
        features[i].copy_(tmp)

    tensor = torch.stack(features, dim=0).detach()
    # tensor.requires_grad = False

    return tensor.to(device)

In [42]:
def get_text_features_per_layer(clip_model, text):
    features = [torch.zeros([*text.shape, 512]) for i in range(clip_model.transformer.layers)]

    #%% clip_model.encode_image

    x = clip_model.token_embedding(text).type(clip_model.dtype)  # [batch_size, n_ctx, d_clip_model]

    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)          # NLD -> LND

    for i in range(clip_model.transformer.layers):
        x = clip_model.transformer.resblocks[i](x)
        tmp = x.permute(1, 0, 2)    # LND -> NLD
        features[i].copy_(tmp)

    tensor = torch.stack(features, dim=0).detach()
    # tensor.requires_grad = False

    return tensor.to(device)

#### lab

In [43]:
model.visual.ln_post(get_image_features_per_layer(model, image).to(device)).shape

torch.Size([12, 1, 768])

In [44]:
model.ln_final

LayerNorm((512,), eps=1e-05, elementwise_affine=True)

In [34]:
model.ln_final(get_text_features_per_layer(model, text).to(device)).shape

torch.Size([12, 8, 77, 512])

In [35]:
model.visual.proj.shape

torch.Size([768, 512])

In [36]:
torch.rand(model.visual.proj.shape)

tensor([[0.7060, 0.3264, 0.8114,  ..., 0.3201, 0.6791, 0.0149],
        [0.8198, 0.7222, 0.8461,  ..., 0.4048, 0.2797, 0.4489],
        [0.4857, 0.4051, 0.6350,  ..., 0.2085, 0.0135, 0.5073],
        ...,
        [0.2146, 0.3847, 0.2523,  ..., 0.9563, 0.4808, 0.8936],
        [0.9526, 0.3367, 0.8129,  ..., 0.4689, 0.7941, 0.3421],
        [0.1183, 0.8862, 0.1054,  ..., 0.7301, 0.8447, 0.9090]])

In [37]:
proj = torch.stack([model.visual.proj for _ in range(12)])
proj.shape

torch.Size([12, 768, 512])

In [38]:
tmp = [model.visual.proj for _ in range(12)]
tmp[0] is tmp[1], proj[0] is proj[1]

(True, False)

In [39]:
print(proj[0][0][0].item(), proj[1][0][0].item())
proj[0][0][0] = -1
print(proj[0][0][0].item(), proj[1][0][0].item())
# light copy

-0.0026264190673828125 -0.0026264190673828125
-1.0 -0.0026264190673828125


In [42]:
tmp = get_image_features_per_layer(model, image).to(device)
tmp1 = model.visual.ln_post(tmp)[0]
tmp2 = model.visual.ln_post(tmp[0])
(tmp1 == tmp2).sum()

tensor(768, device='cuda:0')

#### inference

In [157]:
# Define model
class OurCLIP(nn.Module):        
    def __init__(self, clip_model, # use pre-trained clip model
                 use_one_ln1=True, use_one_ln2=True, 
                 use_one_projection1=True, use_one_projection2=True,
                 trainable_ln1=False, trainable_ln2=False, 
                 trainable_projection1=False, trainable_projection2=False,
                 threshold = 20): 
        super().__init__()

        self.dtype = clip_model.dtype
        self.threshold = threshold

        ####################### 미구현
        if use_one_ln1:
            self.ln_post = clip_model.visual.ln_post
        else: pass
        if use_one_ln2:
            self.ln_final = clip_model.ln_final
        else: pass
        ########################

        if use_one_projection1:
            clip_model.visual.proj.requires_grad = True
            self.visual_projection = torch.stack([clip_model.visual.proj for _ in range(12)]).type(torch.float)
            if not trainable_projection1:
                self.visual_projection.detach().type(torch.float)
        else:
            self.visual_projection = torch.stack([*[torch.rand(clip_model.visual.proj.shape) for _ in range(11)],
                                                   clip_model.visual.proj]).type(torch.float)
        
        if use_one_projection2:
            clip_model.text_projection.requires_grad = True
            self.textual_projection = torch.stack([clip_model.text_projection for _ in range(12)])
            if not trainable_projection2:
                self.visual_projection.detach()
        else:
            self.textual_projection = torch.stack([*[torch.rand(clip_model.text_projection.shape) for _ in range(11)], 
                                                   clip_model.text_projection])


    def forward(self, image_features, text, text_features):    
        # image_features = image_features.to(device)
        # text_features = text_features.to(device)

        image_features = self.ln_post(image_features)                   # (12, batch_size, 768)
        image_embeddings = image_features @ self.visual_projection      # (12, batch_size, 512)  <- we'll use it

        text_features = self.ln_final(text_features).type(self.dtype)   # (12, seq_len, 77, 512)
        text_embeddings = text_features[:, torch.arange(text_features.shape[1]), text.argmax(dim=-1)] @ self.textual_projection  
                                                                        # (12, seq_len, 512)  <- we'll use it
        print(text_embeddings.shape, image_embeddings.shape)
        scores = [[[], [], []] for _ in range(image_features.shape[1])]

        for k in range(image_features.shape[1]):
            for _class in range(text_features.shape[1]):
                # print(image_embeddings[0].shape)
                # print(text_embeddings[0][[_class]].T.shape)
                # print((image_embeddings[0] @ text_embeddings[0][[_class]].T.type(torch.float)).item())
                img_similarity_by_layer = [[(image_embeddings[i][[k]] @ text_embeddings[j][[_class]].T.type(torch.float)).item()
                                            for j in range(12)] for i in range(12)]
                scores[k][0].append(np.max(img_similarity_by_layer))
                scores[k][1].append(np.mean(img_similarity_by_layer))
                scores[k][2].append(sum(list(map(lambda x: sum([y > self.threshold for y in x]), img_similarity_by_layer))))

        return scores

In [158]:
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

model = OurCLIP(clip_model, trainable_projection1=True, trainable_projection2=True).to(device)  # model을 initialize하는 부분
print(model)

OurCLIP(
  (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)


### lab

In [74]:
image_title = 'A horse in the space.png'

prompts = [
    "A horse in the space",
    "A dog in the space",
    "A bear in the space",
    "A person in the space",
    "A horse in the park",
    "A dog in the park",
    "A bear in the park",
    "A person in the park",
]

image = preprocess(Image.open(image_title)).unsqueeze(0).to(device)
text = clip.tokenize(prompts).to(device)

In [75]:
image_features = get_image_features_per_layer(clip_model, image)
text_features = get_text_features_per_layer(clip_model, text)

In [49]:
model(image_features, text_features)

[[33.82638168334961,
  26.978994369506836,
  26.948450088500977,
  30.468639373779297,
  25.630268096923828,
  20.795486450195312,
  22.559879302978516,
  23.461559295654297],
 [12.221618531478775,
  12.280218217107985,
  12.698671211798986,
  14.021323535177443,
  11.308028957910008,
  10.396428364846441,
  11.075213006801075,
  10.72875845928987],
 [14, 13, 13, 13, 6, 1, 6, 11]]

In [50]:
clip_model.text_projection.shape

torch.Size([512, 512])

## Train

### Load-Dataset

In [51]:
import deeplake
from PIL import Image
import numpy as np
import os, time
import torch
from torchvision import transforms, models

In [52]:
# Connect to the training and testing datasets
ds_train = deeplake.load('hub://activeloop/pacs-train')
ds_test = deeplake.load('hub://activeloop/pacs-test')

tform = transforms.Compose([
    transforms.RandomRotation(20), # Image augmentation
    transforms.ToTensor(), # Must convert to pytorch tensor for subsequent operations to run
    transforms.Normalize([0.5], [0.5]),
])
batch_size = 4

hub://activeloop/pacs-train loaded successfully.
This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/pacs-train
hub://activeloop/pacs-test loaded successfully.
This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/pacs-test


In [143]:
# Specity the loss function and optimizer
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.1)

In [144]:
pacs_class = [
    'a dog', 'an elephant', 'a giraffe', 'a guitar', 'a horse', 'a house', 'a person'
]

def prompt(idx):
    return f"An image of {pacs_class[idx]}"

prompts = clip.tokenize([prompt(x) for x in range(7)]).to(device)
print(prompts.shape)

torch.Size([7, 77])


In [145]:
# Since torchvision transforms expect PIL images, we use the 'pil' decode_method for the 'images' tensor. This is much faster than running ToPILImage inside the transform
train_loader = ds_train.pytorch(num_workers = 0, shuffle = True, 
                                transform = {'images': tform, 'labels': None}, 
                                batch_size = batch_size, decode_method = {'images': 'pil'})
test_loader = ds_test.pytorch(num_workers = 0, transform = {'images': tform, 'labels': None}, 
                                batch_size = batch_size, decode_method = {'images': 'pil'})

In [129]:
def train_one_epoch(model, optimizer, data_loader, device, prompts):

    model.train()

    # Zero the performance stats for each epoch
    running_loss = 0.0
    start_time = time.time()
    total = 0
    correct = 0
    
    for i, data in enumerate(data_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs = data['images']
        labels = torch.squeeze(data['labels'])

        inputs = inputs.to(device)
        labels = labels.to(device)

        image_features = get_image_features_per_layer(clip_model, inputs)
        text_features = get_text_features_per_layer(clip_model, prompts)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        ###################################### output이 Tensor가 아닌 list인데 괜찮은지
        scores = model(image_features, prompts, text_features)
        outputs = torch.Tensor(list(map(lambda x: x[0], scores))).to(device)   # 여기서 tensor 객체를 새롭게 정의하는 것이 의미있는가
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        ######################################
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
    
        # Print performance statistics
        running_loss += loss.item()
        if i % 10 == 0:    # print every 10 batches
            batch_time = time.time()
            speed = (i+1)/(batch_time-start_time)
            print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' %
                  (i, running_loss, speed, accuracy))

            running_loss = 0.0
            total = 0
            correct = 0
            break

    
def test_model(model, data_loader):

    model.eval()

    start_time = time.time()
    total = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # get the inputs; data is a list of [inputs, labels]
            inputs = data['images']
            labels = torch.squeeze(data['labels'])

            inputs = inputs.to(device)
            labels = labels.to(device)

            image_features = get_image_features_per_layer(clip_model, inputs)
            text_features = get_text_features_per_layer(clip_model, prompts)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            scores = model(image_features, prompts, text_features)
            outputs = torch.Tensor(list(map(lambda x: x[0], scores))).to(device)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
            
        print('Finished Testing')
        print('Testing accuracy: %.1f %%' %(accuracy))

### lab

In [63]:
inputs.shape, prompts.shape
get_image_features_per_layer(clip_model, inputs).shape, get_text_features_per_layer(clip_model, prompts).shape

(torch.Size([12, 4, 768]), torch.Size([12, 7, 77, 512]))

In [107]:
image_features = get_image_features_per_layer(clip_model, inputs)
text_features = get_text_features_per_layer(clip_model, prompts)

In [116]:
scores = model(image_features, prompts, text_features)

torch.Size([12, 7, 512]) torch.Size([12, 4, 512])


In [126]:
outputs = torch.Tensor(list(map(lambda x: x[0], scores))).to(device)
outputs, labels

(tensor([[33.2395, 28.6383, 27.9835, 37.3564, 32.7737, 30.8021, 35.0590],
         [33.6642, 28.3727, 27.1275, 25.2947, 40.4309, 29.3082, 34.7931],
         [33.8334, 29.4328, 28.4747, 37.0757, 34.1582, 29.9939, 36.0595],
         [33.4167, 37.2045, 26.9609, 24.3298, 32.5465, 29.5641, 35.3329]],
        device='cuda:0'),
 tensor([3, 4, 3, 1], device='cuda:0'))

In [127]:
loss = criterion(outputs, labels)
loss

tensor(0.1682, device='cuda:0')

In [128]:
torch.max(outputs.data, 1)

torch.return_types.max(
values=tensor([37.3564, 40.4309, 37.0757, 37.2045], device='cuda:0'),
indices=tensor([3, 4, 3, 1], device='cuda:0'))

### inference

In [159]:
num_epochs = 3
for epoch in range(num_epochs):  # loop over the dataset multiple times
    print("------------------ Training Epoch {} ------------------".format(epoch+1))
    train_one_epoch(model, optimizer, train_loader, device, prompts)

    test_model(model, test_loader)

print('Finished Training')

------------------ Training Epoch 1 ------------------
torch.Size([12, 7, 512]) torch.Size([12, 4, 512])


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [137]:
model.visual_projection.requires_grad_(True)

tensor([[[-2.6264e-03,  5.0962e-05,  2.7496e-02,  ..., -1.0025e-02,
          -1.2222e-02,  5.8403e-03],
         [-1.9852e-02,  7.1182e-03,  8.9788e-04,  ...,  1.1528e-02,
          -1.9485e-02, -8.0185e-03],
         [-8.6288e-03,  1.9226e-03, -2.1725e-03,  ...,  3.9330e-03,
          -1.1269e-02,  1.5345e-03],
         ...,
         [-1.1993e-02,  1.2955e-02,  2.5848e-02,  ..., -9.8038e-03,
          -4.2076e-03,  1.5211e-04],
         [-1.2871e-02, -9.5673e-03, -1.0826e-02,  ..., -7.0610e-03,
          -4.3182e-03, -4.9353e-04],
         [-4.4098e-03,  3.3588e-03, -1.2054e-02,  ...,  6.1073e-03,
           3.9940e-03, -3.0861e-03]],

        [[-2.6264e-03,  5.0962e-05,  2.7496e-02,  ..., -1.0025e-02,
          -1.2222e-02,  5.8403e-03],
         [-1.9852e-02,  7.1182e-03,  8.9788e-04,  ...,  1.1528e-02,
          -1.9485e-02, -8.0185e-03],
         [-8.6288e-03,  1.9226e-03, -2.1725e-03,  ...,  3.9330e-03,
          -1.1269e-02,  1.5345e-03],
         ...,
         [-1.1993e-02,  1