In [None]:
!pip install ftfy
!git clone https://github.com/openai/CLIP.git
%cd /content/CLIP

In [None]:
import torch
import clip
from PIL import Image

class TextTransformer(torch.nn.Module):
  """Module wrapper whose forward pass delegates to ``clip_model.encode_text``.

  Routing the text encoder through ``forward`` lets the wrapper be handed to
  APIs that invoke a module by calling it.
  """

  def __init__(self, clip_model):
    """Keep a reference to ``clip_model`` for use in ``forward``."""
    super(TextTransformer, self).__init__()
    self.clip_model = clip_model

  def forward(self, x: torch.Tensor):
    """Return the result of ``self.clip_model.encode_text(x)``."""
    encoded = self.clip_model.encode_text(x)
    return encoded


# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Load the "ViT-B/32" checkpoint; clip.load also returns the callable that is
# applied to the PIL image below.
model, preprocess = clip.load("ViT-B/32", device=device)

# Example inputs, both moved to `device`. The image gets a new leading axis of
# size 1 (presumably the batch dimension — confirm against preprocess's output).
image = torch.unsqueeze(preprocess(Image.open("CLIP.png")), 0).to(device)
text = clip.tokenize("a cat").to(device)

# Run the encoders and the full model without tracking gradients.
with torch.no_grad():
  image_feature = model.encode_image(image)
  text_feature = model.encode_text(text)

  logits_per_image, logits_per_text = model(image, text)
  # NOTE(review): with a single prompt the softmax over the last axis presumably
  # always yields 1.0 — confirm whether more prompts were intended.
  probs = torch.softmax(logits_per_image, dim=-1).cpu().numpy()

print("Label probs:", probs)

In [None]:
# model.visual is called directly here, which bypasses the input cast that CLIP's
# encode_image performs. On CUDA the loaded weights are half precision, so tracing
# with the float32 `image` fails with a dtype mismatch. Cast the example input to
# the model's dtype first (a no-op on CPU, where the model is float32).
# NOTE: on CUDA the exported module therefore expects inputs of this dtype.
visual_example = image.type(model.dtype)
clip_visual_opt = torch.jit.optimize_for_inference(torch.jit.trace(model.visual, visual_example))
# Persist the traced, inference-optimized image encoder.
clip_visual_opt.save("/content/clip_visual.pt")

In [None]:
# Wrap the model so the text encoder is reachable through forward(), trace it with
# the example tokens, optimize the traced module for inference, and save it.
text_transformer = TextTransformer(model)
traced_text = torch.jit.trace(text_transformer, text)
clip_text_opt = torch.jit.optimize_for_inference(traced_text)
clip_text_opt.save("/content/clip_text.pt")