In [124]:
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as T
import clip
# from transformers import CLIPProcessor, CLIPModel
import os
import gc
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from sentence_transformers import SentenceTransformer, util

In [125]:
# Index of the GPU used by every model in this notebook. It is referenced
# through the variable everywhere so that it only has to be changed here.
cuda_device_id = 3
torch.cuda.set_device(cuda_device_id)
torch.cuda.get_device_name(cuda_device_id)

'Tesla P40'

In [126]:
# Collect unreachable Python objects first, so that any CUDA tensors they still
# hold are returned to PyTorch's caching allocator, and only then ask the
# allocator to release its unused cached blocks. Doing it the other way round
# leaves the memory freed by the collection sitting in the cache.
unreachable_count = gc.collect()
torch.cuda.empty_cache()
unreachable_count

0

In [127]:
# List the model names that the installed `clip` package reports as available.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [128]:
# Alternative checkpoints tried earlier (kept for reference):
# model = torch.load('/home/pyt_user/pp/pytorch/clip/checkpoints/ckpt_e0.pth')
# model = torch.load('/mnt/nis_lab_research/data/clip_data/pth/far_shah_b1-b5_b8_train_neg_vitb32_ep30/model_final.pth')
# processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

device = torch.device(f"cuda:{cuda_device_id}")
# device = "cpu"

# clip.load returns the network together with its matching image preprocessing
# transform, so one call is enough. Loading RN50x64 a second time just to get
# `preprocess` doubled the load time and the peak memory use.
model, preprocess = clip.load("RN50x64", device=device, jit=False)

# Written as an assignment so the cell does not echo the whole module repr.
model = model.eval()


In [129]:
# Give the device to the constructor instead of moving the model afterwards:
# the constructor records it as the model's target device, which is what
# encode() works with, so the model and its inputs stay on the selected GPU.
model_sbert = SentenceTransformer('all-MiniLM-L6-v2', device=str(device))

In [130]:
def get_top_5(similarity, label_names=None):
    """Return the best-scoring labels for the first row of ``similarity``.

    Args:
        similarity: Tensor whose first row holds one score per label.
        label_names: Label strings indexed like the columns of ``similarity``.
            When omitted, the notebook-level ``labels`` list is used, which
            keeps existing single-argument calls working.

    Returns:
        A list of ``[label, score * 100]`` pairs ordered from highest to
        lowest score. It holds at most five pairs, and fewer when there are
        fewer than five scores.
    """
    if label_names is None:
        label_names = labels  # fall back to the notebook-level label list

    scores = similarity[0]
    # topk raises when k exceeds the number of scores, so clamp it.
    k = min(5, scores.shape[-1])
    if k == 0:
        return []

    values, indices = scores.topk(k)
    return [
        [label_names[index], 100 * value]
        for value, index in zip(values.tolist(), indices.tolist())
    ]

In [131]:
def run_clip_inference(image, context, labels):
    """Score every label against an image and against its text context with CLIP.

    Uses the notebook-level ``model``, ``preprocess`` and ``device``.

    Args:
        image: Path (or open file object) of the image to classify, or an
            already opened ``PIL.Image.Image``.
        context: Free text that accompanies the image.
        labels: Sequence of candidate label strings; they are lower-cased
            before tokenization.

    Returns:
        ``[similarity_img, similarity_cont, similarity_comb]``. The first two
        are softmax distributions over the labels, for image-vs-label and
        context-vs-label respectively; the third is their element-wise sum.
        Each has one row and one column per label.
    """
    # Use the image that was passed in rather than a notebook global, and close
    # the file as soon as the preprocessing transform has read it.
    if isinstance(image, Image.Image):
        image_tensor = preprocess(image)
    else:
        with Image.open(image) as pil_image:
            image_tensor = preprocess(pil_image)
    image_tensor = image_tensor.unsqueeze(0).to(device)

    # CLIP's limit is 77 *tokens*, not characters: slicing the string both
    # wastes context and can still overflow, so let the tokenizer truncate.
    context_tensor = clip.tokenize(context, truncate=True).to(device)
    labels_tensor = clip.tokenize([txt.lower() for txt in labels]).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        context_features = model.encode_text(context_tensor)
        label_features = model.encode_text(labels_tensor)

        # L2-normalise each embedding by its OWN norm so the dot products
        # below are cosine similarities.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        context_features = context_features / context_features.norm(dim=-1, keepdim=True)
        label_features = label_features / label_features.norm(dim=-1, keepdim=True)

        # Similarity score for image vs label
        similarity_img = (100.0 * image_features @ label_features.T).softmax(dim=-1)

        # Similarity score for context vs label
        similarity_cont = (100.0 * context_features @ label_features.T).softmax(dim=-1)

    similarity_comb = similarity_img + similarity_cont

    return [similarity_img, similarity_cont, similarity_comb]



In [132]:
def run_sbert_inference(context, labels):
    """Compare a context string with every label using the SBERT model.

    Both inputs are embedded with the notebook-level ``model_sbert`` and the
    cosine similarity between the context embedding and the label embeddings
    is returned, as computed by ``util.pytorch_cos_sim``.
    """
    return util.pytorch_cos_sim(
        model_sbert.encode(context),
        model_sbert.encode(labels),
    )

In [133]:
# Root of the test set: the evaluation cell lists this directory and treats
# each entry as one label whose folder holds the .png / .txt samples.
in_dir = "/mnt/nis_lab_research/data/clip_data/test/test1"
# Text file that the evaluation cell appends its per-label results to.
res_out_path = "/home/pyt_user/pp/pytorch/clip/res/test1_res_clip_rn50x64.txt"

In [134]:
# Each sub-directory of in_dir is one class and its name is the ground-truth
# label. Stray files are skipped: they cannot be listed as sample folders.
labels = sorted(
    entry for entry in os.listdir(in_dir)
    if os.path.isdir(os.path.join(in_dir, entry))
)

for label in labels:

    sd_path = os.path.join(in_dir, label)
    sd_entries = os.listdir(sd_path)
    img_fn_list = sorted(fn for fn in sd_entries if fn.endswith('.png'))
    txt_fn_list = sorted(fn for fn in sd_entries if fn.endswith('.txt'))

    # Images and context files are paired by position in the sorted lists, so
    # differing counts would silently pair an image with the wrong text.
    if len(img_fn_list) != len(txt_fn_list):
        raise ValueError(
            f"{sd_path}: found {len(img_fn_list)} .png files but "
            f"{len(txt_fn_list)} .txt files; cannot pair images with contexts"
        )

    clip_ctr = 0
    sbert_ctr = 0
    comb_ctr = 0

    for img_fn, txt_fn in zip(img_fn_list, txt_fn_list):

        image_path = os.path.join(sd_path, img_fn)
        txt_path = os.path.join(sd_path, txt_fn)

        with open(txt_path, "r") as f:
            context = f.read().lower()

        # Run inference (pass the path of the current sample explicitly)
        clip_similarity = run_clip_inference(image_path, context, labels)
        sbert_similarity = run_sbert_inference(context, labels)

        # Only add the text-based score when there is text to score.
        if context != "":
            comb_similarity = clip_similarity[0].to("cpu") + sbert_similarity.to("cpu")
        else:
            comb_similarity = clip_similarity[0].to("cpu")

        clip_sim_lab = get_top_5(clip_similarity[0])[0][0]
        sbert_sim_lab = get_top_5(sbert_similarity)[0][0]
        comb_sim_lab = get_top_5(comb_similarity)[0][0]

        if label == clip_sim_lab:
            clip_ctr += 1

        if label == sbert_sim_lab:
            sbert_ctr += 1

        if label == comb_sim_lab:
            comb_ctr += 1

    n_samples = len(img_fn_list)
    # An empty class folder has all counters at 0; dividing by 1 reports an
    # accuracy of 0.0 instead of raising ZeroDivisionError.
    denom = max(n_samples, 1)

    # Append so the results of all labels accumulate in one file.
    with open(res_out_path, 'a') as file:
        file.write(f"{label}\n")
        file.write(f"clip via img: {clip_ctr} {n_samples} {clip_ctr / denom}\n")
        file.write(f"sbert: {sbert_ctr} {n_samples} {sbert_ctr / denom}\n")
        file.write(f"comb: {comb_ctr} {n_samples} {comb_ctr / denom}\n\n")
