<a href="https://colab.research.google.com/github/zhangluustb/transformers_notebook/blob/main/colab_taiyi_clip_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install clip transformers

# Zero-Shot Classification

In [5]:
from PIL import Image
import requests
import clip
import torch
from transformers import BertForSequenceClassification, BertConfig, BertTokenizer
from transformers import CLIPProcessor, CLIPModel
import numpy as np

In [9]:
# Candidate captions to score against the image; replace with any texts you like.
query_texts = [
    "一只猫",
    "一只狗",
    "两只猫",
    "两只老虎",
    "一只老虎",
]

# Load the Taiyi Chinese text encoder together with its tokenizer.
text_tokenizer = BertTokenizer.from_pretrained(
    "IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese"
)
text_encoder = BertForSequenceClassification.from_pretrained(
    "IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese"
).eval()

# Tokenize the whole batch, padding to the longest caption, and keep only the
# token ids (the rest of the tokenizer output, e.g. the attention mask, is dropped).
# NOTE(review): because the mask is discarded, calling the encoder with `text`
# alone lets it attend to the padded positions — confirm this matches how the
# checkpoint is meant to be used.
text = text_tokenizer(query_texts, padding=True, return_tensors='pt').input_ids
text

tensor([[ 101,  671, 1372, 4344,  102,    0],
        [ 101,  671, 1372, 4318,  102,    0],
        [ 101,  697, 1372, 4344,  102,    0],
        [ 101,  697, 1372, 5439, 5988,  102],
        [ 101,  671, 1372, 5439, 5988,  102]])

In [7]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # any image URL can be used here

# Load CLIP's image encoder and the matching preprocessor.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Download the image. The timeout keeps the cell from hanging forever on a
# stalled connection, and raise_for_status() surfaces HTTP errors (404, 500, ...)
# directly instead of handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = processor(images=Image.open(response.raw), return_tensors="pt")

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/961k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

In [11]:
with torch.no_grad():
    # Embed the image with CLIP and the captions with the Taiyi text encoder
    # (the text embedding is read from the classification head's `logits`).
    image_features = clip_model.get_image_features(**image)
    text_features = text_encoder(text).logits

    # Scale every row to unit L2 norm so the dot product below is a cosine similarity.
    image_norm = image_features.norm(dim=1, keepdim=True)
    text_norm = text_features.norm(dim=1, keepdim=True)
    image_features = image_features / image_norm
    text_features = text_features / text_norm

    # Cosine similarity between each image and each caption, multiplied by
    # CLIP's exponentiated logit_scale coefficient.
    logit_scale = clip_model.logit_scale.exp()
    logits_per_image = torch.matmul(logit_scale * image_features, text_features.T)
    logits_per_text = logits_per_image.T

    # Softmax across captions gives one probability row per image.
    probs = torch.softmax(logits_per_image, dim=-1).cpu().numpy()
    print(probs.round(3))

[[0.01 0.   0.99 0.   0.  ]]


# Zero-Shot Text-to-Image Retrieval

In [12]:
# TODO Step 1: compute the logits for every image.
# TODO Step 2: iterate over the text inputs and compute the top-k retrieval accuracy.