In [25]:
import cv2
import torch
import sys, os
import shutil
from tqdm import tqdm
import os.path as osp
import numpy as np
from groundingdino.util.inference import load_model, load_image, predict, annotate
from segment_anything import sam_model_registry, SamPredictor
import clip
from torch.nn.functional import cosine_similarity
from torchvision.ops import box_convert, box_iou
import matplotlib.pyplot as plt
from openai import OpenAI
from dotenv import load_dotenv
import base64
import requests
import os 

sys.path.append("segment-anything")


load_dotenv()

# Set model name and API key
MODEL = 'gpt-4o'
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)


IMAGE_PATH = "GroundingDINO/weights/inpaint_demo.jpg"
TEXT_PROMPT_CAPTURE = "object on the hand"
TEXT_PROMPT_RECOGNIZE = "object"
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25



In [26]:
def encode_image(image_path):
    with open(image_path, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')

def get_calorie(base64_image):
    response = client.chat.completions.create(
    model= MODEL,
    messages=[
        {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"}
            },{
                "type": "text",
                "text": "Send back name and estimated total calorie of this item using certain form below:\nName:\nCalorie: (only value + kcal)"
            }
        ]},
    ],
    temperature=1,
    max_tokens=256,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0
    )
    
    # Name and Calorie of the food item by API
    obj_info_byOPENAI = response.choices[0].message.content.split('\n')
    if len(obj_info_byOPENAI) != 2:
        return None, None
    else:
        obj_name = obj_info_byOPENAI[0]
        obj_calorie = obj_info_byOPENAI[1]
        return obj_name, obj_calorie


def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    order = torch.argsort(-scores)
    indices = torch.arange(bboxes.shape[0])
    keep = torch.ones_like(indices, dtype=torch.bool)
    for i in indices:
        if keep[i]:
            bbox = bboxes[order[i]]
            iou = box_iou(bbox[None,...],(bboxes[order[i + 1:]]) * keep[i + 1:][...,None])
            overlapped = torch.nonzero(iou > iou_threshold)
            keep[overlapped + i + 1] = 0
    return order[keep]

def getJetColorRGB(v, vmin, vmax):
    c = np.zeros((3))
    if (v < vmin):
        v = vmin
    if (v > vmax):
        v = vmax
    dv = vmax - vmin
    if (v < (vmin + 0.125 * dv)): 
        c[0] = 256 * (0.5 + (v * 4)) #B: 0.5 ~ 1
    elif (v < (vmin + 0.375 * dv)):
        c[0] = 255
        c[1] = 256 * (v - 0.125) * 4 #G: 0 ~ 1
    elif (v < (vmin + 0.625 * dv)):
        c[0] = 256 * (-4 * v + 2.5)  #B: 1 ~ 0
        c[1] = 255
        c[2] = 256 * (4 * (v - 0.375)) #R: 0 ~ 1
    elif (v < (vmin + 0.875 * dv)):
        c[1] = 256 * (-4 * v + 3.5)  #G: 1 ~ 0
        c[2] = 255
    else:
        c[2] = 256 * (-4 * v + 4.5) #R: 1 ~ 0.5                      
    return c

def ground_dino_predict(model, img_path, text_prompt, box_threshold=0.35, text_threshold=0.25, topK=10):
    image_source, image = load_image(img_path)

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )

    # nms
    keep = nms(boxes, logits, 0.5)
    boxes = boxes[keep]
    logits = logits[keep]

    print("Predicted boxes:", boxes.shape[0])

    return boxes, logits

def sam_predict(predictor, img_path, boxes):
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_rgb = image.copy()
    predictor.set_image(image)

    # get from object detector
    h, w, _ = image.shape
    boxes = boxes * torch.Tensor([w, h, w, h], device=boxes.device)
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")

    input_boxes = xyxy.to(predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    return masks, xyxy, img_rgb

def ground_dino_sam_predict(model, predictor, img_path, text_prompt, box_threshold=0.35, text_threshold=0.25):
    boxes, logits = ground_dino_predict(model, img_path, text_prompt, box_threshold, text_threshold)
    masks, boxes, img_rgb = sam_predict(predictor, img_path, boxes)
    return masks, boxes, img_rgb

def load_model_and_predict():
    #Use GroundingDino to detect items.
    print("Loading GroundingDINO model...")
    model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "GroundingDINO/weights/groundingdino_swint_ogc.pth")

    # Use SAM to generate the mask.
    print("Loading Segment Anything model...")
    sam_checkpoint = "segment-anything/sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cpu"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device) 
    predictor = SamPredictor(sam)

    print("Loading clip ViT-B/32 model...")
    extractor, _ = clip.load("ViT-B/32", device, jit=False)

    return model, predictor, extractor

def extract_saved_obj_features(model, predictor, extractor):
    print("Entering extract_saved_obj_features")
    input_path = "./saved_obj_img"
    output_path = "./extract_saved_obj_feature"
    if not osp.exists(output_path):
        os.mkdir(output_path)
    else:
        # remove all files
        shutil.rmtree(output_path)
        os.mkdir(output_path)

    obj_list = [oi for oi in os.listdir(input_path) if not "DS_Store" in oi]
    obj_list.sort()
    
    obj_features = []

    for obj_dir in tqdm(obj_list):
        obj_path = osp.join(input_path, obj_dir)
        features_all = []
        if not osp.exists(osp.join(output_path, obj_dir)):
            os.mkdir(osp.join(output_path, obj_dir))
        for obj in [oi for oi in os.listdir(obj_path) if not "DS_Store" in oi]:
            img_path = osp.join(obj_path, obj)

            # ground dino and sam inference
            # mask 
            masks, boxes, img_rgb = ground_dino_sam_predict(model, predictor, img_path, TEXT_PROMPT_CAPTURE)
            x1, y1, x2, y2 = boxes[0].cpu().numpy()
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            
            # mask
            masked_img = img_rgb * masks[0].cpu().numpy().transpose(1, 2, 0)
            cv2.imwrite(osp.join(output_path, obj_dir, "mask_"+obj), cv2.cvtColor(masked_img, cv2.COLOR_RGB2BGR))

            # crop and pad to center
            masked_img = masked_img[y1:y2, x1:x2, :]
            
            h, w, _ = masked_img.shape
            size = max(h, w)
            img_pad = np.zeros((size, size, 3)).astype(np.uint8)
            img_pad[size//2-h//2:size//2+(h-h//2), size//2-w//2:size//2+(w-w//2)] = masked_img
            img_pad = cv2.resize(img_pad, (224, 224))
            
            # covered img pad
            img_pad_rtop = img_pad.copy()
            img_pad_ltop = img_pad.copy()
            img_pad_rdown = img_pad.copy()
            img_pad_ldown = img_pad.copy()


            img_pad_rtop[0:90, 134:224, :] = (0, 0, 0)
            img_pad_ltop[0:90, 0:90, :] = (0, 0, 0)
            img_pad_rdown[134:224, 134:224, :] = (0, 0, 0)
            img_pad_ldown[134:224, 0:90, :] = (0, 0, 0)
            
            img_pad_covered = [img_pad, img_pad_rtop, img_pad_ltop, img_pad_rdown, img_pad_ldown]
                
    
            if not osp.exists(osp.join(output_path, obj_dir, "covered")):
                os.mkdir(osp.join(output_path, obj_dir, "covered"))
            output_covered_path = osp.join(output_path, obj_dir, "covered")

            cv2.imwrite(osp.join(output_path, obj_dir, "mask_crop_"+obj), cv2.cvtColor(img_pad, cv2.COLOR_RGB2BGR))
            cv2.imwrite(osp.join(output_covered_path, "mask_crop_rtop_"+obj), cv2.cvtColor(img_pad_rtop, cv2.COLOR_RGB2BGR))
            cv2.imwrite(osp.join(output_covered_path, "mask_crop_ltop_"+obj), cv2.cvtColor(img_pad_ltop, cv2.COLOR_RGB2BGR))
            cv2.imwrite(osp.join(output_covered_path, "mask_crop_rdown_"+obj), cv2.cvtColor(img_pad_rdown, cv2.COLOR_RGB2BGR))
            cv2.imwrite(osp.join(output_covered_path, "mask_crop_ldown_"+obj), cv2.cvtColor(img_pad_ldown, cv2.COLOR_RGB2BGR))


            # normalize
            for img in img_pad_covered:
                image = img.astype(np.float32) / 255.
                image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).cpu()
                mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(image.device)
                std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(image.device)
                mean = mean.view(1, -1, 1, 1)
                std = std.view(1, -1, 1, 1)

                image = (image - mean) / std

                features = extractor.encode_image(image)
                features_all.append(features)
      
        features_all = torch.cat(features_all, dim=0)
        obj_features.append(features_all.unsqueeze(0))
    
    obj_features = torch.cat(obj_features, dim=0)
    obj_features = {
        "features": obj_features,
        "obj_list": obj_list
    }
    torch.save(obj_features, osp.join(output_path, "obj_features1.pt"))
    print("saved obj_features.pt!!")

    return obj_features

def recognize_pipeline(model, predictor, extractor, obj_features, recognize_img_path, box_threshold=0.35, text_threshold=0.25, idx=0):

    obj_list = obj_features["obj_list"]
    obj_features = obj_features["features"]
    
    
    if not osp.exists(f"recognized_results/{idx}"):
        os.mkdir(f"recognized_results/{idx}")
    else:
        # remove all files
        shutil.rmtree(f"recognized_results/{idx}")
        os.mkdir(f"recognized_results/{idx}")

    # ground dino and sam inference
    masks, boxes, img_rgb = ground_dino_sam_predict(model, predictor, recognize_img_path, TEXT_PROMPT_RECOGNIZE, box_threshold, text_threshold)
    img_rgb_copy = img_rgb.copy()

    for obj_idx, (boxi, maski) in enumerate(zip(boxes, masks)):
        x1, y1, x2, y2 = boxi.cpu().numpy()
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

        area = (x2-x1) * (y2-y1)
        if area > 1000*1000:
            continue

        # mask
        masked_img = img_rgb * maski.cpu().numpy().transpose(1, 2, 0)
        cv2.imwrite(f"./recognized_results/{idx}/mask_{obj_idx}.jpg", cv2.cvtColor(masked_img, cv2.COLOR_RGB2BGR))

        # crop and pad to center
        masked_img = masked_img[y1:y2, x1:x2, :]
        h, w, _ = masked_img.shape
        size = max(h, w)
        img_pad = np.zeros((size, size, 3)).astype(np.uint8)
        img_pad[size//2-h//2:size//2+(h-h//2), size//2-w//2:size//2+(w-w//2)] = masked_img
        img_pad = cv2.resize(img_pad, (224, 224))
        cv2.imwrite(f"./recognized_results/{idx}/mask_crop_{obj_idx}.jpg", cv2.cvtColor(img_pad, cv2.COLOR_RGB2BGR))

        # normalize
        image = img_pad.astype(np.float32) / 255.
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).cpu()
        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(image.device)
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(image.device)
        mean = mean.view(1, -1, 1, 1)
        std = std.view(1, -1, 1, 1)

        image = (image - mean) / std

        # extract features
        features = extractor.encode_image(image)  # (1, 512)

        similarity = cosine_similarity(features, obj_features, dim=-1)
        sim_order = torch.argmax(similarity, dim=-1)

        res_text = []
        max_score_idx = 0
        max_score = 0
        for i in range(len(sim_order)):
            sim_t = similarity[i][sim_order[i]]
            if sim_t > max_score:
                max_score = sim_t
                max_score_idx = i
            res_text.append(f"{obj_list[i]}: {sim_t:.3f}")
        
        # write to image
        y1_t = y1
        for rti_idx, rti in enumerate(res_text):
            # # print all
            # if not rti_idx == max_score_idx:
            #     cv2.putText(img_rgb_copy, rti, (x1+6, y1_t+40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
            # else:
            #     cv2.putText(img_rgb_copy, rti, (x1+6, y1_t+40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
            # y1_t += 40

            # print max score only
            if rti_idx == max_score_idx:
                cv2.putText(img_rgb_copy, rti, (x1+6, y1_t+40), cv2.FONT_HERSHEY_SIMPLEX, 1.8, (0, 255, 0), 2)
                # call API
                CROP_IMAGE_PATH = f'./recognized_results/{idx}/mask_crop_{obj_idx}.jpg'
                base64_image = encode_image(CROP_IMAGE_PATH)
                _, obj_calorie = get_calorie(base64_image)
                cv2.putText(img_rgb_copy, obj_calorie, (x1+6, y1_t+100), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 2)
        cv2.rectangle(img_rgb_copy, (x1, y1), (x2, y2), (0, 0, 255), 2)
    
    # cv2.imwrite(f"./recognized_results/{idx}/res_{box_threshold}_{text_threshold}.jpg", cv2.cvtColor(img_rgb_copy, cv2.COLOR_RGB2BGR))
    cv2.imwrite(f"./recognized_results/{idx}/res_{box_threshold}_{text_threshold}_withCalorie.jpg", cv2.cvtColor(img_rgb_copy, cv2.COLOR_RGB2BGR))


In [28]:
model, predictor, extractor = load_model_and_predict()

Loading GroundingDINO model...




final text_encoder_type: bert-base-uncased




Loading Segment Anything model...




Loading clip ViT-B/32 model...


In [None]:
obj_features = extract_saved_obj_features(model, predictor, extractor)
# obj_features = torch.load("./extract_saved_obj_feature/obj_features.pt")


In [None]:
input_path = "./saved_obj_img"
output_path = "./extract_saved_obj_feature"
if not osp.exists(output_path):
    os.mkdir(output_path)
else:
    # remove all files
    shutil.rmtree(output_path)
    os.mkdir(output_path)

obj_list = [oi for oi in os.listdir(input_path) if not "DS_Store" in oi]
obj_list.sort()
print(obj_list)

In [None]:
# obj_features = torch.load("./extract_saved_obj_feature/obj_features.pt")
obj_features1 = torch.load("./extract_saved_obj_feature/obj_features1.pt")
obj_features2 = torch.load("./extract_saved_obj_feature/obj_features2.pt")
print(obj_features1['features'].size())
print(obj_features2['features'].size())
obj_features = {}
obj_features['features'] = torch.cat((obj_features1['features'], obj_features2['features']), dim=0)
obj_features['obj_list'] = obj_features1['obj_list'] + obj_features2['obj_list']
print("Size of features tensor in obj_features:", obj_features1['features'].size())
print("Size of features tensor in obj_features:", obj_features2['features'].size())
print(obj_features['features'].size())

torch.save(obj_features, "./extract_saved_obj_feature/obj_features.pt")

In [17]:
obj_features = torch.load("./extract_saved_obj_feature/obj_features.pt")
# for i in range(8, 20):
#     recognize_pipeline(model, predictor, extractor, obj_features, f"./test_img/reg_all_{i}.jpg", 0.15, 0.15, idx=i)
i = 19
recognize_pipeline(model, predictor, extractor, obj_features, f"./test_img/reg_all_{i}.jpg", 0.15, 0.15, idx=i)


Predicted boxes: 14
