In [64]:
import torch
import clip
from PIL import Image

# Load CLIP (ViT-B/32) on the GPU when one is available, then encode a single
# image/caption pair into the joint embedding space.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image_title = 'A dog is sitting on a couch with a stuffed animal.png'

# Open the file through a context manager so the handle is released as soon as
# the preprocessed tensor exists (Image.open alone leaves the file open).
with Image.open(image_title) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)   # add batch dimension

# The file name doubles as the caption.
# NOTE(review): the prompt therefore includes the ".png" extension; the outputs
# recorded further down in this notebook were produced with exactly this text,
# so it is kept unchanged here -- confirm whether the extension is intended.
text = clip.tokenize(image_title).to(device)

# Inference only: no autograd graph is needed for the embeddings.
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

In [3]:
# Shape of the token-id tensor for a single prompt (recorded output: [1, 77]).
clip.tokenize("A dog is sitting on a couch with a stuffed animal.png").shape

torch.Size([1, 77])

In [4]:
# Shape of one preprocessed image before a batch dimension is added
# (recorded output: [3, 224, 224]).
# NOTE(review): the image opened here is never closed; prefer
# `with Image.open(...) as img:` outside of throwaway inspection cells.
preprocess(Image.open(image_title)).shape

torch.Size([3, 224, 224])

In [27]:
# Inspect the second LayerNorm of the first residual block in the image tower
# (the recorded repr shows it normalizes over 768 features).
model.visual.transformer.resblocks[0].ln_2

LayerNorm((768,), eps=1e-05, elementwise_affine=True)

In [32]:
# Same inspection for the text tower's first residual block
# (the recorded repr shows it normalizes over 512 features).
model.transformer.resblocks[0].ln_2

LayerNorm((512,), eps=1e-05, elementwise_affine=True)

In [41]:
# Joint forward pass; the recorded output is a pair of 1x1 logit tensors
# (image-vs-text and its transpose).
# NOTE(review): this call runs outside torch.no_grad(), which is why the
# recorded tensors carry a grad_fn; wrap it in no_grad() for pure inference.
model(image, text)

(tensor([[28.7812]], device='cuda:0', dtype=torch.float16,
        grad_fn=<MmBackward0>),
 tensor([[28.7812]], device='cuda:0', dtype=torch.float16, grad_fn=<TBackward0>))

In [43]:
# Shape of the image embedding returned by model.encode_image (recorded: [1, 512]).
image_features.shape

torch.Size([1, 512])

In [44]:
# Shape of the text embedding returned by model.encode_text (recorded: [1, 512]).
text_features.shape

torch.Size([1, 512])

In [68]:
# Shape of the matrix that is right-multiplied onto the image tower's output
# (recorded: [768, 512], i.e. 768-wide features -> 512-dim embedding).
model.visual.proj.shape

torch.Size([768, 512])

In [79]:
# Display the module tree of the image tower (a VisionTransformer per the repr below).
model.visual

VisionTransformer(
  (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
  (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (transformer): Transformer(
    (resblocks): Sequential(
      (0): ResidualAttentionBlock(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): QuickGELU()
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (1): ResidualAttentionBlock(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise

In [116]:
# Per-layer image embeddings. After the loop below, entry i holds the class-token
# embedding taken after residual block i, passed through ln_post and the visual
# projection -- i.e. what encode_image would return if the tower stopped there.
# Stored as float32 CPU tensors.
img_embeddings = []

#%% model.encode_image, unrolled step by step

# Inference only: without no_grad() every intermediate below would be tracked
# by autograd and the stored embeddings would keep the whole graph alive.
with torch.no_grad():
    # image -> tokens
    x = model.visual.conv1(image.type(model.visual.conv1.weight.dtype))
    x = x.reshape(x.shape[0], x.shape[1], -1)   # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)                      # shape = [*, grid ** 2, width]

    # Prepend the class embedding to every sequence; adding it to a zero tensor
    # broadcasts it to [batch, 1, width].
    class_token = model.visual.class_embedding.to(x.dtype) + torch.zeros(
        x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
    )
    x = torch.cat([class_token, x], dim=1)      # shape = [*, grid ** 2 + 1, width]
    x = x + model.visual.positional_embedding.to(x.dtype)
    x = model.visual.ln_pre(x)

    # tokens -> transformer -> feature embeddings
    x = x.permute(1, 0, 2)  # NLD -> LND

    # Equivalent to `x = model.visual.transformer(x)`, which simply runs the
    # `resblocks` Sequential; the blocks are applied one at a time here so the
    # intermediate state after each layer (first ... last) can be captured.
    for block in model.visual.transformer.resblocks:
        x = block(x)
        layer_cls = model.visual.ln_post(x.permute(1, 0, 2)[:, 0, :])
        if model.visual.proj is not None:
            layer_cls = layer_cls @ model.visual.proj
        # Appending (instead of copying into a pre-sized [1, 512] buffer) keeps
        # this correct for any batch size and embedding width.
        img_embeddings.append(layer_cls.float().cpu())

    x = x.permute(1, 0, 2)  # LND -> NLD
    print(x.shape)

    x = model.visual.ln_post(x[:, 0, :])    # use the embedding of the [CLS] token
    print(x.shape)

    if model.visual.proj is not None:
        x = x @ model.visual.proj
    print(x.shape)

torch.Size([1, 50, 768])
torch.Size([1, 768])
torch.Size([1, 512])


In [115]:
# Sanity check: shape of the embedding captured after the first image layer
# (recorded: [1, 512]).
img_embeddings[0].shape

torch.Size([1, 512])

In [139]:
# Per-layer text embeddings. After the loop below, entry i holds the EOT-token
# embedding taken after residual block i, passed through ln_final and the text
# projection -- i.e. what encode_text would return if the tower stopped there.
# Stored as float32 CPU tensors.
txt_embeddings = []

#%% model.encode_text, unrolled step by step

# Inference only: without no_grad() every intermediate below would be tracked
# by autograd and the stored embeddings would keep the whole graph alive.
with torch.no_grad():
    x = model.token_embedding(text).type(model.dtype)  # [batch_size, n_ctx, d_model]

    x = x + model.positional_embedding.type(model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND

    # Indices that select the EOT token of every sequence. The eot token has the
    # highest id in each sequence, so argmax over the token ids gives its position.
    # Computed once because `text` does not change inside the loop.
    batch_idx = torch.arange(text.shape[0], device=text.device)
    eot_idx = text.argmax(dim=-1)

    # Equivalent to `x = model.transformer(x)`; handled almost exactly like the
    # image tower, one block at a time so each layer's state can be captured.
    for block in model.transformer.resblocks:
        x = block(x)
        layer_tokens = model.ln_final(x.permute(1, 0, 2)).type(model.dtype)
        layer_eot = layer_tokens[batch_idx, eot_idx] @ model.text_projection
        # Appending (instead of copying into a pre-sized [1, 512] buffer) keeps
        # this correct for any batch size and embedding width.
        txt_embeddings.append(layer_eot.float().cpu())

    x = x.permute(1, 0, 2)  # LND -> NLD
    x = model.ln_final(x).type(model.dtype)
    print(x.shape)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # Take the features at the EOT position of each sequence:
    #   batch_idx : one row index per prompt in the batch (x.shape[0] is the
    #               batch size, not the number of word tokens)
    #   eot_idx   : position of the end-of-text token in each prompt
    # Unlike the image tower, the end token's embedding is used, not a [CLS] token.
    x = x[batch_idx, eot_idx] @ model.text_projection
    print(x.shape)


torch.Size([1, 77, 512])
torch.Size([1, 512])


In [121]:
# The two index tensors used to pick the end-of-text embedding: a row index per
# batch element and the argmax position of each token sequence
# (recorded: tensor([0]) and tensor([14])).
torch.arange(x.shape[0]), text.argmax(dim=-1)

(tensor([0]), tensor([14], device='cuda:0'))

In [126]:
# Shape of the matrix right-multiplied onto the selected text token embedding
# (recorded: [512, 512]).
model.text_projection.shape

torch.Size([512, 512])

In [137]:
import numpy as np
# Toy illustration of the `x[row_indices, position]` indexing pattern:
# a list index on axis 0 combined with an integer on axis 1 selects row 1 of
# sample 0 and keeps the leading axis, giving a (1, 3) result.
sample = [[1, 2, 3], [4, 5, 6]]
tmp = np.array([sample, sample])   # two identical samples -> shape (2, 2, 3)
tmp[[0], 1]

array([[4, 5, 6]])

In [140]:
# Sanity check: shape of the embedding captured after the first text layer
# (recorded: [1, 512]).
txt_embeddings[0].shape

torch.Size([1, 512])

In [145]:
#%% model.forward(image, text), reproduced step by step

# Everything runs under a single no_grad(): model.logit_scale is read from the
# model, so computing the logits outside the context would attach autograd
# history to them for no benefit.
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # L2-normalize so the dot product below is a cosine similarity.
    # Note: this rebinds image_features / text_features to the normalized versions.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # cosine similarity as logits, scaled by exp(logit_scale)
    logit_scale = model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    # shape = [n_images, n_texts]; .item() works here only because both are 1.
    print(logits_per_image.item(), logits_per_text.item())

    # (+) get probabilities
    # NOTE: softmax runs over the text axis; with a single caption it is always
    # 1.0, so several candidate captions are needed for a meaningful result.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)

28.78125 28.78125
Label probs: [[1.]]


In [163]:
# Peek at the first five components of each feature vector.
# NOTE(review): if the model.forward cell ran before this one, these are the
# L2-normalized features rather than the raw encoder outputs -- confirm.
print(image_features[0, :5])
print(text_features[0, :5])

tensor([ 0.0232,  0.0191, -0.0026,  0.0169,  0.0393], device='cuda:0',
       dtype=torch.float16)
tensor([ 0.0200,  0.0095, -0.0619, -0.0078, -0.0046], device='cuda:0',
       dtype=torch.float16)
