## MiniCLIP to Colab
A Colab notebook port of the MiniCLIP demo (https://github.com/HendrikStrobelt/miniClip) by Hendrik Strobelt.

Install the missing dependencies.


In [None]:
!pip install torchray 

Install CLIP.

In [None]:
!pip install git+https://github.com/openai/CLIP.git

Add the MiniCLIP user defined functions.

In [None]:
from PIL import Image
import numpy as np
import torch
from matplotlib import cm

def min_max_norm(array):
    """
    linearly rescales a tensor so that its minimum maps to 0 and its maximum to approximately 1;
    the input is not modified, a new tensor is returned
    """
    lowest, highest = array.min(), array.max()
    # the small epsilon keeps the division finite when all values are equal
    scale = 1 / (1.e-10 + (highest - lowest))
    shifted = array - lowest
    shifted.mul_(scale)
    return shifted

def torch_to_rgba(img):
    """
    takes a channel-first tensor with 3 or 4 channels and returns it min-max normalized,
    channel-last and on the cpu; 3-channel input gets an extra channel filled with ones
    """
    channels_last = min_max_norm(img).permute(1, 2, 0).cpu()
    height, width, channels = channels_last.shape
    if channels == 3:
        # append a fully opaque alpha channel
        opaque = torch.ones(height, width, 1)
        channels_last = torch.cat((channels_last, opaque), dim=2)
    assert channels_last.shape[2] == 4
    return channels_last

def numpy_to_image(img, size):
    """
    converts a [0..1] normalized rgba array into a [0...255] uint8 PIL image
    and resizes it to size x size pixels
    """
    as_bytes = (img * 255.).astype(np.uint8)
    picture = Image.fromarray(as_bytes)
    return picture.resize((size, size))

def upscale_pytorch(img: np.ndarray, size):
    """
    upsamples a channel-last (H, W, C) numpy image to the given spatial size with
    torch.nn.Upsample (default interpolation mode) and returns a channel-last numpy array
    """
    # Upsample works on batched, channel-first input: (H, W, C) -> (1, C, H, W)
    torch_img = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1, 2)
    upsampler = torch.nn.Upsample(size=size)
    # drop the batch dimension again and go back to channel-last
    return upsampler(torch_img)[0].permute(1, 2, 0).cpu().numpy()


def heatmap(image:torch.Tensor, heatmap: torch.Tensor, size=None, alpha=.6):
    """
    renders a saliency map as a square rgba PIL image using matplotlib's "hot" colormap

    image:   only used to derive the default output size (image.shape[1])
    heatmap: saliency values; min-max normalized and converted with .numpy() before colormapping
    size:    edge length in pixels of the returned image; defaults to image.shape[1]
    alpha:   accepted for backward compatibility but ignored - the heat map is returned
             on its own and is not blended with the image
    """
    if not size:
        size = image.shape[1]
    colored = cm.hot(min_max_norm(heatmap).numpy())  # [0...1] rgba numpy "image"
    return numpy_to_image(colored, size)

Import other necessary dependencies.

In [None]:
import clip
from torchray.attribution.grad_cam import grad_cam

User defined function to get the CLIP ResNet50 model.

In [None]:
def get_model():
    """
    loads the CLIP "RN50" model onto the notebook-level `device` with jit disabled
    and returns the result of clip.load unchanged
    """
    loaded = clip.load("RN50", device=device, jit=False)
    return loaded

User defined function to upload an image.

In [None]:
def upload_files():
  """Open the Colab upload dialog, write every uploaded file to disk under its
  uploaded name and return the list of those names."""
  # Kept local so google.colab only has to be importable when the function is called.
  from google.colab import files
  uploaded = files.upload()
  for name, content in uploaded.items():
    # The context manager closes the handle even if the write fails.
    with open(name, 'wb') as handle:
      handle.write(content)
  return list(uploaded.keys())

Set the device to use.

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Build the Form.

In [None]:
#@title Options

# `alpha` is a raw dropdown so it reaches the code as a number rather than a string.
alpha = 0.5 #@param [0.5, 0.7, 0.8] {type:"raw"}
layer = 'layer4.2.relu' #@param ["layer4.2.relu"]

Upload an image.

In [None]:
# Opens the upload dialog and keeps the names of the uploaded files.
uploaded_image_list = upload_files()

Enter some descriptive text.

In [None]:
#@title Enter some descriptive texts

# `textarea`: the descriptions to compare against the image, separated by ';'.
# `prefix`: text prepended to every description; leave empty to use the descriptions as they are.
textarea = 'a pizza; a beer' #@param {type:"string"}
prefix = 'an image of' #@param {type:"string"}

Read and preprocess the uploaded image.

In [None]:
# Only the first uploaded file is opened.
image_raw = Image.open(uploaded_image_list[0])

In [None]:
# get_model() returns a pair, unpacked here as the CLIP model and the callable used to preprocess images.
model, preprocess = get_model()

In [None]:
# Preprocess the image, add a leading batch dimension and move the tensor to `device`.
image = preprocess(image_raw).unsqueeze(0).to(device)

Preprocess text.

In [None]:
# Split the ';'-separated descriptions and drop blank entries, so that a trailing or
# doubled ';' does not produce an empty category.
prefix = prefix.strip()
categories = [entry.strip() for entry in textarea.split(';') if entry.strip()]
if not categories:
    raise ValueError("Enter at least one description (separate several with ';').")
if prefix:
    categories = [f"{prefix} {entry}" for entry in categories]
text = clip.tokenize(categories).to(device)

Calculate the saliency map.

In [None]:
# Embeddings and similarity scores are computed without gradient tracking.
with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        # Divide each embedding by its norm along the last dimension; the norms are kept as separate variables.
        image_features_norm = image_features.norm(dim=-1, keepdim=True)
        image_features_new = image_features / image_features_norm
        text_features_norm = text_features.norm(dim=-1, keepdim=True)
        text_features_new = text_features / text_features_norm
        # Scaled dot products between the normalized image and text embeddings.
        logit_scale = model.logit_scale.exp()
        logits_per_image = logit_scale * image_features_new @ text_features_new.t()
        # Softmax over the last dimension, converted to nested Python lists.
        probs = logits_per_image.softmax(dim=-1).cpu().numpy().tolist()

# Grad-CAM on the visual encoder at `layer`; the un-normalized image embedding is passed as the third argument.
saliency = grad_cam(model.visual, image.type(model.dtype), image_features, saliency_layer=layer)

In [None]:
# Detach the selected saliency map, convert it to a float32 CPU tensor and render it with heatmap().
hm = heatmap(image[0], saliency[0][0,].detach().type(torch.float32).cpu(), alpha=alpha)

Collect the images.

In [None]:
collect_images = []
for i in range(len(categories)):
    # multiply the normalized text embedding with the image norm to get an approximate image embedding
    text_prediction = (text_features_new[[i]] * image_features_norm)
    # Loop-local names: the `saliency` / `hm` computed for the image embedding in the earlier
    # cells must not be overwritten here, because `hm` is displayed again further down.
    text_saliency = grad_cam(model.visual, image.type(model.dtype), text_prediction, saliency_layer=layer)
    text_hm = heatmap(image[0], text_saliency[0][0,].detach().type(torch.float32).cpu(), alpha=alpha)
    collect_images.append(text_hm)
logits = logits_per_image.cpu().numpy().tolist()[0]

Show the Grad Cam for text embeddings.

In [None]:
# One caption per description: "<text> - <probability>/<logit>".
text_embeddings = [f"{category} - {round(prob, 3)}/{round(logit, 2)}"
                   for (category, prob, logit) in zip(categories, probs[0], logits)]
# Show every heat map with its caption, not only the first and the last one.
for category_image, caption in zip(collect_images, text_embeddings):
    display(category_image)
    print(caption)


Show the original image and Grad Cam for image embedding.

In [None]:
# Show the min-max normalized input image next to `hm`.
# NOTE(review): this expects `hm` to still be the heat map computed from the image embedding;
# confirm that no cell in between rebinds `hm`.
display(Image.fromarray((torch_to_rgba(image[0]).numpy() * 255.).astype(np.uint8)), hm)