<a href="https://colab.research.google.com/github/rb58853/images_RIS-ML-Conv-NLP/blob/main/end_model/caption.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install transformers

## Import libraries

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

### Load Image

In [None]:
# Path of the input image (a local file in the Colab runtime, not a URL).
img_url = '/content/2.jpg'
# PIL copy: captioned by BLIP/CLIP and used as the pixel source by mask_image.
raw_image = Image.open(img_url).convert("RGB")
# H x W x 3 uint8 RGB array for SAM and the bounding-box crops. It is derived
# from raw_image instead of re-decoding the file with cv2: cv2.imread applies
# the EXIF orientation tag while PIL's Image.open does not, so the two copies
# could otherwise differ in orientation and SAM's masks would not line up with
# raw_image.
image = np.array(raw_image)

# Load all Models

## Segment Anything Model

In [None]:
import torchvision
# Report library versions and whether a CUDA GPU is visible to PyTorch.
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys
# Install SAM and its post-processing dependencies with the kernel's own
# interpreter ({sys.executable} is expanded by IPython), then fetch the ViT-H
# SAM checkpoint into the working directory.
# NOTE(review): wget without -nc re-downloads the checkpoint on every run and
# saves the extra copies as sam_vit_h_4b8939.pth.1, .2, ...
!{sys.executable} -m pip install opencv-python matplotlib
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth


import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Checkpoint fetched by the setup cell and the backbone it belongs to.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Use the GPU when one is visible rather than hard-coding "cuda", so SAM can
# still be built on a CPU-only runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Generator with SAM's default settings (kept as an alternative).
mask_generator = SamAutomaticMaskGenerator(sam)

# Tuned generator used by all_areas_from_image.
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)

#### Create images from masks and bounding boxes

In [None]:
from PIL import Image

def mask_image(mask, raw_image):
    """Cut the masked region of ``raw_image`` out onto a transparent canvas.

    Args:
        mask: 2-D array indexed ``[row, col]`` whose truthy entries select
            pixels (e.g. SAM's ``'segmentation'``). It must cover at least the
            image's height x width.
        raw_image: PIL image supplying the pixel values.

    Returns:
        A new RGBA PIL image the same size as ``raw_image``. Pixels where
        ``mask`` is truthy are copied from ``raw_image`` (fully opaque for an
        RGB source). Every other pixel is transparent black ``(0, 0, 0, 0)``.
    """
    width, height = raw_image.size
    # Vectorised replacement for the former per-pixel Python loop, which ran
    # width * height interpreter iterations for every mask.
    source = np.asarray(raw_image.convert('RGBA'))
    keep = np.asarray(mask, dtype=bool)[:height, :width]
    canvas = np.zeros((height, width, 4), dtype=np.uint8)
    canvas[keep] = source[keep]
    return Image.fromarray(canvas)

def bbox_image(bbox, image):
    """Crop ``image`` to an ``(x, y, width, height)`` bounding box.

    Only the first four entries of ``bbox`` are read. ``image`` is indexed as
    ``[rows, cols]``, so for a NumPy array the result is a view, not a copy.
    """
    left, top, box_w, box_h = (bbox[k] for k in range(4))
    rows = slice(top, top + box_h)
    cols = slice(left, left + box_w)
    return image[rows, cols]

In [None]:
def all_areas_from_image(image, raw_image):
    """Segment ``image`` with the tuned SAM generator and extract each region.

    ``mask_generator`` (default SAM settings) can be swapped in for
    ``mask_generator_2`` to compare results.

    Returns:
        A dict with two lists, in SAM's mask order:
        ``'box'``: ``image`` cropped to each mask's bounding box;
        ``'mask'``: RGBA cut-outs of ``raw_image`` for each mask.
    """
    records = mask_generator_2.generate(image)
    return {
        'box': [bbox_image(rec['bbox'], image) for rec in records],
        'mask': [mask_image(rec['segmentation'], raw_image) for rec in records],
    }

## BLIP

In [None]:
from transformers import BlipProcessor, BlipForConditionalGeneration

# Use the GPU when one is visible, otherwise fall back to CPU (the CLIP cell
# does the same) instead of hard-coding "cuda".
_blip_device = "cuda" if torch.cuda.is_available() else "cpu"
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to(_blip_device)

In [None]:
def blip(_image):
    """Generate a BLIP caption for one image.

    Args:
        _image: a PIL image, or an H x W x 3 uint8 NumPy array such as a crop
            produced by ``bbox_image``.

    Returns:
        The decoded caption, with a leading ``"there is "`` removed.
    """
    # Hand the processor a PIL image. For NumPy input, the HF image processor
    # guesses the channel axis from the shape and treats a leading dimension
    # of 1 or 3 as channels, so a crop only 1 or 3 pixels tall would be misread.
    if isinstance(_image, np.ndarray):
        _image = Image.fromarray(_image)
    # Put the inputs wherever the model lives instead of hard-coding "cuda".
    inputs = blip_processor(_image, return_tensors="pt").to(blip_model.device)
    out = blip_model.generate(**inputs)
    result = blip_processor.decode(out[0], skip_special_tokens=True)

    prefix = "there is "
    if result.startswith(prefix):
        result = result[len(prefix):]

    return result

def all_captions(image, raw_image):
    """Build candidate captions for the whole image and each SAM region.

    The first entry is the BLIP caption of ``raw_image``. Each following entry
    is that caption plus a space plus the caption of one bounding-box crop.
    (The ``'mask'`` cut-outs from ``all_areas_from_image`` could be captioned
    instead of the ``'box'`` crops.)
    """
    crops = all_areas_from_image(image, raw_image)['box']
    base = str(blip(raw_image))
    return [base] + [f"{base} {blip(crop)}" for crop in crops]

## CLIP

In [None]:
from transformers import CLIPProcessor, CLIPModel

# Prefer the first GPU, fall back to CPU. select_caption and reduce_caption
# send their inputs to this `device`.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

In [None]:
def select_caption(captions, image):
    """Score every caption against ``image`` with CLIP and pick the best one.

    Returns:
        ``{'caption': best caption, 'probs': 1-D tensor of softmax
        probabilities aligned with ``captions``}``.
    """
    inputs = clip_processor(text=captions, images=image, return_tensors="pt", padding=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = clip_model(**inputs)

    logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    probs = logits_per_image.softmax(dim=1)
    return {'caption': select_from_probs(probs, captions), 'probs': probs[0]}

def select_from_probs(probs, captions):
    """Return the caption with the highest score in ``probs[0]``.

    The first row is scanned left to right, starting from a threshold of 0.
    An index wins only if its score is strictly greater than the best so far,
    so ties go to the earliest caption and an all-zero row selects
    ``captions[0]``.
    """
    row = probs[0]
    best_score, best_idx = 0, 0
    for idx, score in enumerate(row):
        if not score > best_score:
            continue
        best_score, best_idx = score, idx
    return captions[best_idx]

def _clip_prefers_first(first, second, image):
    """Return True when CLIP scores ``first`` strictly above ``second`` for ``image``."""
    inputs = clip_processor(text=[first, second], images=image, return_tensors="pt", padding=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():  # inference only
        outputs = clip_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return bool(probs[0][0] > probs[0][1])


def reduce_caption(caption, image):
    """Greedily drop words from ``caption`` while CLIP prefers the shorter text.

    Each word position of the space-split caption is tested once, left to
    right. The word is removed when CLIP scores the caption without it
    strictly higher than the current caption. The caption is never reduced to
    an empty string.

    Returns:
        The reduced caption, or ``caption`` unchanged when no word was dropped.
    """
    words = caption.split(' ')
    position = 0
    while position < len(words):
        if len(words) == 1:
            break  # removing the last word would leave an empty caption
        # Drop the word at this position (not its first occurrence), so that
        # repeated words such as "a" are each tested exactly once.
        candidate = words[:position] + words[position + 1:]
        if _clip_prefers_first(' '.join(candidate), ' '.join(words), image):
            words = candidate  # the next word shifts into `position`
        else:
            position += 1
    return ' '.join(words)

# Run Model

In [None]:
def short_captions(probs, captions):
    """Pair each score with its caption, ordered from most to least likely.

    (The name presumably means "sort": the function orders captions and does
    not shorten them.)

    Args:
        probs: 1-D sequence of scores aligned with ``captions``, e.g. the
            ``'probs'`` tensor returned by ``select_caption`` or a list of floats.
        captions: captions aligned index-by-index with ``probs``.

    Returns:
        A dict mapping each score to its caption, inserted in descending score
        order. Neither argument is modified. The former in-place swap aliased
        tensor views (``temp = probs[i]``), which duplicated the larger value.
        Torch tensor elements hash by identity and so stay separate keys; plain
        float scores that are equal collapse into one entry.
    """
    ranked = sorted(zip(probs, captions), key=lambda pair: float(pair[0]), reverse=True)
    return {prob: caption for prob, caption in ranked}

In [None]:
# Candidate captions: the whole image first, then whole image + each SAM region.
captions = all_captions(image=image, raw_image=raw_image)

In [None]:
# Pick the best caption among the whole-image and per-region candidates.
result = select_caption(captions, raw_image)

print("\no_caption: "+str(result['caption']))
# Drop words CLIP does not miss, then let the shortened caption compete too.
reduced_caption = reduce_caption(result['caption'], raw_image)
print("r_caption: "+str(reduced_caption), end= '\n\n')

if reduced_caption not in captions:
    captions.append(reduced_caption)

# Re-score with the reduced caption included and print candidates best-first.
result = select_caption(captions, raw_image)
probs = result['probs']
end_captions = short_captions(probs, captions)

for prob, caption in end_captions.items():
    print(f"{prob.item() * 100:.2f}%: {caption}")



In [None]:
# Display every SAM bounding-box crop as its own small, axis-free figure.
crops = all_areas_from_image(image, raw_image)['box']
for crop in crops:
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(crop)
    ax.axis('off')
    plt.show()