In [2]:
import torch 
from PIL import Image

import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models
# Show the model names that cn_clip reports as available.
print("Available models:", available_models())
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']

# Use the GPU when PyTorch reports CUDA as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
cuda


In [3]:
# Load the "RN50" model by name, together with its matching preprocessing callable.
# download_root is presumably the directory the weights are fetched into — confirm
# against the cn_clip documentation.
# NOTE(review): a later cell assigns `model, preprocess` again from a local
# ViT-B-16 checkpoint, so this RN50 instance is replaced when both cells are run.
model, preprocess = load_from_name(
    "RN50",
    download_root="./datapath",
    device=device,
)

100%|███████████████████████████████████████| 294M/294M [00:11<00:00, 27.3MiB/s]


Loading vision model config from g:\code\github\Chinese-CLIP\cn_clip\clip\model_configs\RN50.json
Loading text model config from g:\code\github\Chinese-CLIP\cn_clip\clip\model_configs\RBT3-chinese.json
Model info {'embed_dim': 1024, 'image_resolution': 224, 'vision_layers': [3, 4, 6, 3], 'vision_width': 64, 'vision_patch_size': None, 'vocab_size': 21128, 'text_attention_probs_dropout_prob': 0.1, 'text_hidden_act': 'gelu', 'text_hidden_dropout_prob': 0.1, 'text_hidden_size': 768, 'text_initializer_range': 0.02, 'text_intermediate_size': 3072, 'text_max_position_embeddings': 512, 'text_num_attention_heads': 12, 'text_num_hidden_layers': 3, 'text_type_vocab_size': 2}


In [7]:
# Load from a checkpoint file path instead of a model name. The vision/text
# architecture names and the input resolution are passed explicitly here
# (presumably required when the first argument is a file path — confirm against
# the cn_clip documentation).
# This rebinds `model` and `preprocess`, replacing whatever was loaded before.
model, preprocess = load_from_name(
    "./datapath/pretrained_weights/clip_cn_vit-b-16.pt",
    vision_model_name="ViT-B-16",
    text_model_name="RoBERTa-wwm-ext-base-chinese",
    input_resolution=224,
    device=device,
)

Loading vision model config from g:\code\github\Chinese-CLIP\cn_clip\clip\model_configs\ViT-B-16.json
Loading text model config from g:\code\github\Chinese-CLIP\cn_clip\clip\model_configs\RoBERTa-wwm-ext-base-chinese.json
Model info {'embed_dim': 512, 'image_resolution': 224, 'vision_layers': 12, 'vision_width': 768, 'vision_patch_size': 16, 'vocab_size': 21128, 'text_attention_probs_dropout_prob': 0.1, 'text_hidden_act': 'gelu', 'text_hidden_dropout_prob': 0.1, 'text_hidden_size': 768, 'text_initializer_range': 0.02, 'text_intermediate_size': 3072, 'text_max_position_embeddings': 512, 'text_num_attention_heads': 12, 'text_num_hidden_layers': 12, 'text_type_vocab_size': 2}


In [8]:
# Call model.eval() before running inference.
model.eval()

# Preprocess the sample picture, add a leading dimension, and move it to the device.
image = (
    preprocess(Image.open("examples/pokemon.jpeg"))
    .unsqueeze(0)
    .to(device)
)
# Candidate labels (Squirtle, Bulbasaur, Charmander, Pikachu), tokenized and moved
# to the same device.
text = clip.tokenize(
    ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]
).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize the features in place; use the normalized image/text features
    # for downstream tasks.
    image_features.div_(image_features.norm(dim=-1, keepdim=True))
    text_features.div_(text_features.norm(dim=-1, keepdim=True))

    # get_similarity() is given the raw image/text inputs, not the features above.
    logits_per_image, logits_per_text = model.get_similarity(image, text)
    # Softmax over the last axis turns the per-image logits into label probabilities.
    probs = torch.softmax(logits_per_image, dim=-1).cpu().numpy()

# Reference values: [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]
# (the recorded run of this cell differs slightly in the low-order digits).
print("Label probs:", probs)

Label probs: [[1.251e-03 5.490e-02 6.909e-04 9.434e-01]]
