# CLIP

> Contrastive Language–Image Pre-training

In [None]:
#| default_exp image.clip

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

## HF

In [None]:
#| export
from PIL import Image
import requests
from pathlib import Path
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import torch
from tqdm import tqdm
import numpy as np

### Usage

In [None]:
# Load the pretrained CLIP checkpoint and its matching preprocessor.
# `model` and `processor` are reused by the helper functions defined below.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Fetch a sample image; fail fast on a hung connection or an HTTP error
# instead of handing a broken stream to PIL.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
# we can take the softmax over the text prompts to get the label probabilities
# (named `label_probs` so it does not clash with the `probs()` helper defined below)
label_probs = logits_per_image.softmax(dim=1)
print(label_probs)

tensor([[0.9949, 0.0051]], grad_fn=<SoftmaxBackward0>)


In [None]:
def embed_image(images):
    """Compute CLIP image features for one image or a list of images.

    A non-list argument is wrapped in a single-element list before it is
    handed to the notebook-level `processor`. The features come from
    `model.get_image_features` and are computed without gradient tracking.
    """
    batch = images if isinstance(images, list) else [images]
    pixel_inputs = processor(images=batch, return_tensors="pt", padding=True)
    with torch.no_grad():
        features = model.get_image_features(**pixel_inputs)
    return features

def embed_text(text):
    """Compute CLIP text features for `text`.

    `text` is passed to the notebook-level `processor` unchanged (`classify`
    passes a list of prompt strings). The features come from
    `model.get_text_features` and are computed without gradient tracking.
    """
    token_inputs = processor(text=text, return_tensors="pt", padding=True)
    with torch.no_grad():
        features = model.get_text_features(**token_inputs)
    return features

def normalize(a):
    "Divide `a` by its norm taken along the last dimension (kept for broadcasting)."
    lengths = a.norm(dim=-1, keepdim=True)
    return a / lengths

def cosine_sim(a, b):
    "Cosine similarity of `a` against `b`: normalize both, then `a @ b.T` (assumes 2-D inputs whose rows are embeddings)."
    unit_a, unit_b = normalize(a), normalize(b)
    return unit_a @ unit_b.T

def logits(a, b):
    "Cosine similarities of `a` and `b`, multiplied by the exponential of the model's `logit_scale`."
    scale = model.logit_scale.exp()
    return scale * cosine_sim(a, b)

def probs(a, b):
    "Softmax of `logits(a, b)` over dimension 0, i.e. normalized across the entries of `a`."
    scores = logits(a, b)
    return scores.softmax(dim=0)

def classify(image, classes, template="a photo of a {}"):
    """Score `image` against each entry of `classes`.

    Each class name is substituted into `template` to build a text prompt.
    The prompts and the image are embedded, and `probs` is applied with the
    text embeddings first, so the softmax runs across the classes.
    """
    image_embs = embed_image(image)
    prompts = [template.format(label) for label in classes]
    return probs(embed_text(prompts), image_embs)

def search(image_embs, query_embs):
    """Rank image embeddings by cosine similarity to the query embeddings.

    Returns `(indices, similarities)`, both ordered from most to least similar.
    The similarity matrix is flattened before sorting, so the indices map
    directly to images only for a single query — NOTE(review): confirm callers
    never pass several queries at once.
    """
    scores = cosine_sim(image_embs, query_embs).flatten()
    order = scores.argsort(descending=True)
    return order, scores[order]

def thumbnail(image, scale=3):
    """Return a copy of `image` shrunk by a factor of `scale` in each dimension.

    Args:
        image: object exposing `.size` as `(width, height)` and `.resize(size)`
            (e.g. a PIL image).
        scale: positive shrink factor; may be an int or a float.

    Returns:
        Whatever `image.resize` returns for the reduced size.

    Raises:
        ValueError: if `scale` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")
    width, height = image.size
    # Pass a plain tuple of Python ints (not a numpy array), and never let a
    # dimension collapse to 0 when the image is smaller than `scale`.
    new_size = (max(1, int(width // scale)), max(1, int(height // scale)))
    return image.resize(new_size)

In [None]:
# Load the "huggan/few-shot-art-painting" dataset through `datasets.load_dataset`;
# its `train` split is what gets embedded in the next cell.
paintings = load_dataset("huggan/few-shot-art-painting")

Repo card metadata block was not found. Setting CardData to empty.


In [None]:
# Embedding the whole training split is slow, so the result is cached on disk
# as a .npy file and reloaded on later runs.
all_image_embs_path = Path("paintings_embeddings.npy")
if not all_image_embs_path.exists():
    # No cache yet: embed each training image, stack the results, and save them.
    all_image_embs = torch.cat([embed_image(row['image']) for row in tqdm(paintings['train'])])
    np.save(all_image_embs_path, np.array(all_image_embs))
else:
    # Cache hit: read the saved array back into a tensor.
    all_image_embs = torch.tensor(np.load(all_image_embs_path))

100%|██████████| 1000/1000 [02:23<00:00,  6.96it/s]


In [None]:
#| hide
# Run nbdev's export step so cells marked `#| export` are written to the library module.
import nbdev; nbdev.nbdev_export()