## CLIP (Contrastive Language-Image Pre-Training) 

References:
- https://huggingface.co/docs/transformers/model_doc/clip
- https://github.com/mlfoundations/open_clip/blob/main/docs/Interacting_with_open_clip.ipynb
- https://github.com/openai/CLIP

Prerequisites
- Install scikit-image 
```
pip install scikit-image
```

In [None]:
# pip install scikit-image

In [None]:
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModel
import torch

from PIL import Image
import matplotlib.pyplot as plt
import requests

In [None]:
# Network setup needed to download models from Hugging Face through the
# Suwon site proxy.
import os
import ssl

# WARNING: swapping in the *unverified* context turns off TLS certificate
# verification for stdlib HTTPS clients (e.g. urllib) that rely on the default
# context. Keep this only on a trusted internal network.
ssl._create_default_https_context = ssl._create_unverified_context

# Site-specific (Suwon) settings: CA bundle read by `requests`, plus the
# HTTP(S) proxy used for outbound traffic.
os.environ.update({
    'REQUESTS_CA_BUNDLE': '/etc/ssl/certs/ca-certificates.crt',
    'HTTP_PROXY': 'http://75.17.107.42:8080',
    'HTTPS_PROXY': 'http://75.17.107.42:8080',
})

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the cells shown here; the model
# and the processor outputs stay on the CPU unless moved with `.to(device)`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the CLIP checkpoint from the shared dataset volume (a local mirror of
# the Hugging Face Hub layout) instead of downloading it.
dataset_path = '/group-volume/sr_edu/AI-Application-Specialist-Vision-Dataset/'
hf_path = f"{dataset_path}hf-models/"
clip_ckpt = "openai/clip-vit-base-patch32"

# <hf_path>/openai/clip-vit-base-patch32
local_model = f"{hf_path}{clip_ckpt}"
model = CLIPModel.from_pretrained(local_model)
processor = CLIPProcessor.from_pretrained(local_model)

In [None]:
# To download directly from the Hugging Face Hub instead of the local mirror,
# uncomment the lines below (requires the proxy/SSL setup above).
#clip_ckpt = "openai/clip-vit-base-patch32"

#model = CLIPModel.from_pretrained(clip_ckpt) 
#processor = CLIPProcessor.from_pretrained(clip_ckpt)

In [None]:
# Display the loaded CLIPModel's module structure (its repr) in the notebook.
model

## Zero-shot Image classification

In [None]:
from IPython.display import Image as DisplayImage

# Zero-shot classification diagram taken from the openai/CLIP README:
# https://github.com/openai/CLIP?tab=readme-ov-file
zero_shot_figure = dataset_path + 'hf-assets/clip-zero-shot-2.png'
DisplayImage(zero_shot_figure, width=600)

In [None]:
#ls ./images/

In [None]:
# Alternative: load a local file instead of downloading.
#image = Image.open("./images/cat.jpg")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Fail loudly on HTTP errors and don't hang forever on a stalled connection;
# the `with` block releases the connection once the pixels have been read.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    image.load()  # read the pixel data now, before the stream is closed
plt.imshow(image)
plt.axis('off')

In [None]:
# Zero-shot classification: score the downloaded image against candidate captions.
texts = ["a photo of a cat", "a photo of a robot", "a photo of an apple", "a photo of cats"]
inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():  # inference only; no autograd graph needed
    outputs = model(**inputs)
# Image-text similarity logits: one row per image, one column per caption.
logits_per_image = outputs.logits_per_image
# Softmax over the caption axis -> probability of each caption for the image.
probs = logits_per_image.softmax(dim=1)
# A single image is scored, so the flat argmax is the best caption's index;
# `.item()` turns the 0-d tensor into a plain int for list indexing.
probs, texts[probs.argmax().item()]

In [None]:
# Raw image-text similarity logits (before the softmax) from the last forward pass.
logits_per_image

### 더 정확한 text 또는 다른 유사한 text 표현을 더 추가해 실습해 보세요.

In [None]:
# Alternative prompt set to experiment with:
#texts = ["a photo of a cat", "a photo of a robot", "a photo of a turtle", "a photo of a cat and a turtle"]
# More specific captions (counting cats) to see how finely CLIP separates them.
texts = ["a photo of a cat", "a photo of a robot", "a photo of a turtle", "a photo of two cats", "a photo of three cats"]
inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():  # inference only
    outputs = model(**inputs)
# One row per image, one column per caption.
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
# Single image -> flat argmax is the best caption; `.item()` gives a plain int.
probs, texts[probs.argmax().item()]

In [None]:
# Inspect which fields the CLIP forward pass returned (logits, embeddings, ...).
#model.config
#outputs
outputs.keys()

### skimage에 있는 image와 text description으로 실습

In [None]:
#import os
import skimage
import IPython.display
#import matplotlib.pyplot as plt
#from PIL import Image
import numpy as np

from collections import OrderedDict

# Captions for a subset of the sample images bundled with scikit-image,
# keyed by file name without its extension (insertion order is preserved).
descriptions = dict(
    page="a page of text about segmentation",
    chelsea="a facial photo of a tabby cat",
    astronaut="a portrait of an astronaut with the American flag",
    rocket="a rocket standing on a launchpad",
    motorcycle_right="a red motorcycle standing in a garage",
    camera="a person looking at a camera on a tripod",
    horse="a black-and-white silhouette of a horse",
    coffee="a cup of coffee on a saucer",
)

In [None]:
# Load every scikit-image sample that has a caption in `descriptions`, show it
# in a 2x4 grid, and collect the RGB images and captions in matching order.
original_images = []
texts = []
plt.figure(figsize=(16, 5))
# os.listdir order is arbitrary; sorting keeps the image/caption order (and the
# similarity-matrix layout computed later) reproducible between runs.
for filename in sorted(os.listdir(skimage.data_dir)):
    if not filename.endswith((".png", ".jpg")):
        continue
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    # Normalize to 3-channel RGB regardless of the file's mode; the `with`
    # closes the file handle (convert() returns an independent copy).
    with Image.open(os.path.join(skimage.data_dir, filename)) as img:
        image = img.convert("RGB")

    plt.subplot(2, 4, len(original_images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    texts.append(descriptions[name])
    original_images.append(image)

plt.tight_layout()

In [None]:
# Embed every caption (prefixed with "This is ") and every image in one
# forward pass; rows follow the order of `texts` / `original_images`.
text_input = [f"This is {desc}" for desc in texts]
image_input = original_images
inputs = processor(
    text=text_input,
    images=image_input,
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    outputs = model(**inputs)

text_features = outputs.text_embeds.float()    # n_text x emb_dim, e.g. [8, 512]
image_features = outputs.image_embeds.float()  # n_image x emb_dim, e.g. [8, 512]

In [None]:
# Sanity check: one row per caption / per image, both with the same embedding width.
text_features.shape, image_features.shape

In [None]:
# image_features, text_features are already (almost) unit vectors; normalize
# again so the dot product below is exactly the cosine similarity.
# Out-of-place division: if these tensors came from `.float()` on float32
# embeddings they alias `outputs.*_embeds`, and `/=` would mutate those too.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# (n_text, n_image) matrix: row i = caption i, column j = image j.
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

In [None]:
# Heatmap of caption-vs-image cosine similarity, with image thumbnails drawn
# above each column and the score printed in every cell.
# Size the plot from the data actually loaded: some described images may be
# missing from skimage.data_dir, so len(descriptions) can overcount.
n_texts, n_images = similarity.shape
count = n_images

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)  # fixed color scale 0.1-0.3
# plt.colorbar()
plt.yticks(range(n_texts), texts, fontsize=18)
plt.xticks([])
# Thumbnails occupy the strip y in [-1.6, -0.6], above row 0.
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(n_images):
    for y in range(n_texts):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, n_images - 0.5])
# Inverted y-limits: captions top-to-bottom, with room above for thumbnails.
plt.ylim([n_texts + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)

### Zero-shot image classification using CIFAR-100 classes

In [None]:
from torchvision.datasets import CIFAR100, CIFAR10

# Download (once) into ~/.cache; the visible cells use only `dataset.classes`.
dataset = cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True)

In [None]:
# One prompt per CIFAR-100 class. Some class names are snake_case
# (e.g. "maple_tree"), so swap underscores for spaces to give the text
# encoder natural-language phrases.
text_input = [f"A photo of a {label.replace('_', ' ')}" for label in dataset.classes]

image_input = original_images

In [None]:
# Encode all CIFAR-100 prompts together with the skimage images.
inputs = processor(
    text=text_input,
    images=image_input,
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    outputs = model(**inputs)

text_features = outputs.text_embeds.float()    # one row per prompt
image_features = outputs.image_embeds.float()  # one row per image
text_features.shape, image_features.shape

In [None]:
# Probability of each prompt for each image, and the top-k prompts per image.
# Normalize explicitly so the product is a cosine similarity even if the
# features came from get_image_features/get_text_features (later cells),
# which return un-normalized projections.
image_unit = image_features / image_features.norm(dim=-1, keepdim=True)
text_unit = text_features / text_features.norm(dim=-1, keepdim=True)
# 100.0 sharpens the softmax, standing in for CLIP's logit scale.
text_probs = (100.0 * image_unit @ text_unit.T).softmax(dim=-1)
# Guard against fewer than 5 prompts, where topk(5) would raise.
top_probs, top_labels = text_probs.cpu().topk(min(5, text_probs.shape[-1]), dim=-1)

In [None]:
# For each image: the image on the left and a bar chart of its top-k prompts
# on the right (two subplots per image, two images per grid row).
# Derive the row count from the data instead of hard-coding a 4x4 grid; with
# 8 images this is the original 4 rows / (16, 16) figure.
n_rows = max(1, (len(original_images) + 1) // 2)
plt.figure(figsize=(16, 4 * n_rows))

for i, image in enumerate(original_images):
    plt.subplot(n_rows, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(n_rows, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()  # highest probability at the top
    plt.gca().set_axisbelow(True)
    # Alternative: label with the raw class names.
    #plt.yticks(y, [dataset.classes[index] for index in top_labels[i].numpy()])
    # Label each bar with its prompt minus the "a photo of " prefix.
    plt.yticks(y, [text_input[index].lower().replace("a photo of ", "") for index in top_labels[i].tolist()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()

### (ToDo) text 표현을 구분하는 표현이 부족해서 분류 성능이 떨어지는 점을 개선해 봅시다.

classification 성능을 향상시킬 수 있도록 위 실험에서 text 표현을 추가해 보세요.

In [None]:
# Extra prompts for skimage subjects that (presumably) lack a matching
# CIFAR-100 class. Only prompts not already present are added, so re-running
# this cell does not accumulate duplicates.
# NOTE: re-run the inference and top-k cells afterwards; extending the list
# does not change features that were already computed.
extra_prompts = ["A photo of a written paper", "A photo of a cat", "a horse", "a horse icon"]
new_prompts = [p for p in extra_prompts if p not in text_input]
text_input.extend(new_prompts)
text_input

### [참고] model.get_image_features(), get_text_features() 사용하여 구현할 수도 있음

In [None]:
# IPython help syntax (`?obj`, not plain Python): show the signature and
# docstring of the feature-extraction methods used in the next cells.
?model.get_image_features
?model.get_text_features

In [None]:
image_input = original_images

# Image-only path: the processor prepares pixel tensors, and
# get_image_features returns the projected image embeddings (no text needed).
image_inputs = processor(
    images=image_input,
    return_tensors="pt",
)
with torch.no_grad():
    image_features = model.get_image_features(**image_inputs)

In [None]:
# Text-only path: tokenize the prompts and run get_text_features.
# Load the tokenizer from the same local checkpoint as `model`/`processor`
# instead of re-typing the checkpoint name (same path string).
tokenizer = AutoTokenizer.from_pretrained(hf_path + clip_ckpt)
# Snake_case CIFAR-100 class names (e.g. "maple_tree") become natural phrases.
text_input = [f"A photo of a {label.replace('_', ' ')}" for label in dataset.classes]

text_inputs = tokenizer(text_input, padding=True, return_tensors="pt")
with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)

In [None]:
# Shapes from the get_text_features / get_image_features path: one row per
# prompt / per image, with matching embedding width.
text_features.shape, image_features.shape