In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tqdm
import torch
from IPython.display import display

import clip
from clip.model import CLIP
from utils.init_model import model, preprocess, device, load_model, load

loading JIT archive /home/yuki/.cache/clip/ViT-B-32.pt


In [2]:
# Architecture hyper-parameters. These values appear to match the ViT-B/32
# checkpoint referenced below (TODO confirm against clip's model config).
embed_dim = 512  # embedding size; encode_text output below is 512 wide

# Vision tower
image_resolution, vision_patch_size = 224, 32
vision_layers, vision_width = 12, 768

# Text tower (context_length matches the 77-token tokenizer output below)
context_length, vocab_size = 77, 49_408
transformer_width, transformer_heads, transformer_layers = 512, 8, 12

In [3]:
# Override the depths from the previous cell: both the vision and the text
# transformer get 3 layers (presumably to build a small model to experiment with).
vision_layers = transformer_layers = 3

In [5]:
# Build CLIP through the project's `load` helper using the reduced depths.
# `load_state_dict=False` presumably skips copying the pretrained ViT-B/32
# weights (TODO confirm in utils.init_model). Only element [0] of the returned
# sequence (the model) is kept; the other element(s) are discarded.
# NOTE(review): this cell is In [5] but sits above In [4] (the `CLIP(...)` cell),
# so the saved outputs below were produced by *this* model. A top-to-bottom
# Restart & Run All would instead end with the `CLIP(...)` model. Keep only one.
model = load(
    'ViT-B/32',
    device=device,
    jit=False,
    vision_layers=vision_layers,
    transformer_layers=transformer_layers,
    load_state_dict=False,
)[0]

loading JIT archive /home/yuki/.cache/clip/ViT-B-32.pt


In [4]:
# Construct CLIP directly from the hyper-parameters defined above (including the
# reduced depths); no state dict is loaded in this cell.
# NOTE(review): this is In [4] yet appears after In [5]; whichever `model = ...`
# cell runs last determines the model used downstream. Keep only one of them.
model = CLIP(
    embed_dim, image_resolution,
    vision_layers, vision_width, vision_patch_size,
    context_length, vocab_size,
    transformer_width, transformer_heads, transformer_layers,
).to(device)

In [6]:
# Summarise the model: total parameter count plus the input sizes it expects.
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# `Tensor.numel()` gives each tensor's element count directly (1 for a 0-d
# tensor). The builtin `sum` stays an exact int even for an empty iterator,
# whereas `np.sum([])` returns the float 0.0.
n_params = sum(p.numel() for p in model.parameters())

print("Model parameters:", f"{n_params:,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 59,115,009
Input resolution: 224
Context length: 77
Vocab size: 49408


In [7]:
# Tokenise a single caption. `clip.tokenize` takes a list of strings; judging by
# the printed shape, the short caption is padded out to 77 tokens.
text = "a photo of a cat"
tokenized_text = clip.tokenize([text]).to(device)
print(tokenized_text.size())

torch.Size([1, 77])


In [8]:
# Print the qualified name of every registered parameter, one per line.
# Names prefixed with `visual.` belong to the vision tower.
for qualified_name, _tensor in model.named_parameters():
    print(qualified_name)

positional_embedding
text_projection
logit_scale
visual.class_embedding
visual.positional_embedding
visual.proj
visual.conv1.weight
visual.ln_pre.weight
visual.ln_pre.bias
visual.transformer.resblocks.0.attn.in_proj_weight
visual.transformer.resblocks.0.attn.in_proj_bias
visual.transformer.resblocks.0.attn.out_proj.weight
visual.transformer.resblocks.0.attn.out_proj.bias
visual.transformer.resblocks.0.ln_1.weight
visual.transformer.resblocks.0.ln_1.bias
visual.transformer.resblocks.0.mlp.c_fc.weight
visual.transformer.resblocks.0.mlp.c_fc.bias
visual.transformer.resblocks.0.mlp.c_proj.weight
visual.transformer.resblocks.0.mlp.c_proj.bias
visual.transformer.resblocks.0.ln_2.weight
visual.transformer.resblocks.0.ln_2.bias
visual.transformer.resblocks.1.attn.in_proj_weight
visual.transformer.resblocks.1.attn.in_proj_bias
visual.transformer.resblocks.1.attn.out_proj.weight
visual.transformer.resblocks.1.attn.out_proj.bias
visual.transformer.resblocks.1.ln_1.weight
visual.transformer.resblo

In [9]:
# Run the text encoder on the tokenised caption.
# NOTE(review): no `torch.no_grad()` context is used, so autograd may record this
# call if the parameters require grad; wrap it in `torch.no_grad()` if the cell
# is inference-only.
embedded_text = model.encode_text(tokenized_text)

In [10]:
# One embedding per caption; its width should equal embed_dim (512).
embedded_text.size()

torch.Size([1, 512])