In [None]:
!pip install git+https://github.com/huggingface/transformers.git
!pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
from io import BytesIO

import requests
from PIL import Image

# from Segment Anything demo
url = 'https://segment-anything.com/assets/gallery/GettyImages-1207721867.jpg'
# Use a timeout so a stalled connection cannot hang the cell, and fail loudly on an HTTP
# error instead of handing an error page to PIL (which would raise an unrelated decode error).
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')

# # from local device
# img_path = './demo.jpg'
# image = Image.open(img_path).convert('RGB')

# `width` and `height` are reused by later cells to build down-scaled previews.
width, height = image.size
display(image.resize((width // 3, height // 3)))

In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator

# parameters: run on the GPU in half precision when CUDA is available, otherwise CPU / float32
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
torch_dtype = torch.float16 if use_cuda else torch.float32
model_type = 'vit_h'
checkpoint = 'sam_vit_h_4b8939.pth'

# SAM initialization: look up the builder for the chosen backbone and load the checkpoint
build_sam = sam_model_registry[model_type]
model = build_sam(checkpoint=checkpoint)
model.to(device)

# Both helpers wrap the same model instance.
predictor = SamPredictor(model)
mask_generator = SamAutomaticMaskGenerator(model)

# load the image to predictor
predictor.set_image(np.array(image))

model

In [None]:
# Point prompts for the model:
#   input_point - an Nx2 array of points, each given as (X, Y) in pixels
#   input_label - a length-N array of labels; 1 marks a foreground point, 0 a background point
input_point = np.array([[1800, 950]])
input_label = np.array([1])

# A single point is an ambiguous prompt, so SAM is asked for several candidate masks
# together with a predicted quality score for each of them.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
# Keep the candidate the model itself rates best rather than whichever happens to come first.
# From here on `masks` is a single 2-D (H x W) mask, which is what the following cells expect.
masks = masks[np.argmax(scores)]

display(Image.fromarray(masks).resize((width // 3, height // 3)))

In [None]:
crop_mode = "wo_bg" # 'wo_bg' discards the background (fills it with 255); any other value, e.g. 'w_bg', keeps the image unchanged.

if crop_mode == "wo_bg":
    # Broadcast the 2-D mask over the colour channels: keep pixels inside the mask, set the rest to 255.
    keep = masks[:, :, np.newaxis]
    masked_image = np.where(keep, np.array(image), 255).astype(np.uint8)
else:
    masked_image = np.array(image)
masked_image = Image.fromarray(masked_image)

display(masked_image.resize((width // 3, height // 3)))

In [None]:
def boundary(inputs):
    """Return the first and last row index at which a 2-D array attains its maximum.

    The array is scanned in row-major order. For a boolean mask with at least one
    True entry this gives the (inclusive) first and last rows that contain a set
    pixel; pass the transposed mask to obtain the column extent instead. When all
    entries are equal (e.g. an all-False mask) the result spans every row.
    """
    n_cols = inputs.shape[1]
    flat = inputs.reshape(-1)

    # argmax returns the first occurrence of the maximum; scanning the reversed
    # view and mirroring the index yields the last occurrence.
    first_hit = np.argmax(flat)
    last_hit = flat.size - 1 - np.argmax(flat[::-1])

    # Integer division by the row length converts a flat index into a row index.
    return first_hit // n_cols, last_hit // n_cols

def seg_to_box(seg_mask, size):
    """Compute the bounding box of a segmentation mask, scaled by ``size``.

    The row extent comes from ``boundary(seg_mask)`` and the column extent from
    ``boundary(seg_mask.T)``. Each edge is then divided by ``size`` (the caller
    passes the longer image side, which normalizes the values to the 0 ~ 1 range).

    Returns:
        ``[left, top, right, bottom]`` as normalized coordinates.
    """
    row_first, row_last = boundary(seg_mask)
    col_first, col_last = boundary(seg_mask.T)

    # we normalize the size of boundary to 0 ~ 1
    return [edge / size for edge in (col_first, row_first, col_last, row_last)]

# Normalize by the longer image side so the box coordinates fall in 0 ~ 1.
size = max(masks.shape[0], masks.shape[1])
left, top, right, bottom = seg_to_box(masks, size) # position of the top-left and bottom-right corners in the image
print(left, top, right, bottom)

# seg_to_box reports the last row/column that still contains the mask (inclusive), while the
# right/lower edges of a PIL crop box are exclusive. Extend them by one pixel so the final
# row and column are kept and a one-pixel-wide mask does not produce an empty crop.
crop_box = (
    int(round(left * size)),
    int(round(top * size)),
    int(round(right * size)) + 1,
    int(round(bottom * size)) + 1,
)
image_crop = masked_image.crop(crop_box) # crop the image
display(image_crop)

In [None]:
!pip install accelerate bitsandbytes

In [None]:
from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration

blip2_checkpoint = "Salesforce/blip2-opt-2.7b"

processor = AutoProcessor.from_pretrained(blip2_checkpoint)
# Request 8-bit weights through an explicit quantization config: passing `load_in_8bit`
# directly to `from_pretrained` is deprecated in favour of `quantization_config`.
captioning_model = Blip2ForConditionalGeneration.from_pretrained(
    blip2_checkpoint,
    device_map="sequential",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

In [None]:
# Image-only captioning: preprocess the crop, move it to `device` with `torch_dtype`,
# generate up to 50 new tokens and decode the first returned sequence.
inputs = processor(image_crop, return_tensors="pt")
inputs = inputs.to(device, torch_dtype)
out = captioning_model.generate(max_new_tokens=50, **inputs)
captions = processor.decode(out[0], skip_special_tokens=True)
captions = captions.strip()

captions

In [None]:
# Prompted captioning: the text prompt is passed to the processor together with the crop.
text_prompt = 'Question: What does the image show? Answer:'

inputs = processor(image_crop, text=text_prompt, return_tensors="pt")
inputs = inputs.to(device, torch_dtype)
out = captioning_model.generate(max_new_tokens=50, **inputs)
captions = processor.decode(out[0], skip_special_tokens=True)
captions = captions.strip()

captions