In [1]:
"""Model and preprocess function loading for grounded SAM experiment.
"""
import warnings
warnings.simplefilter("ignore")
import sys
sys.path.extend(["../", "./"])

import torch
import torchvision
from PIL import Image
import cv2
import open_clip
from segment_anything.utils.transforms import ResizeLongestSide
from segment_anything import sam_model_registry, build_sam, SamPredictor
from groundingdino.util.misc import nested_tensor_from_tensor_list

from models.grounded_sam import *
from models.GroundingDINO.groundingdino.util.inference import load_image
from linear_probe import LinearProbe

from models.GroundingDINO.groundingdino.models.GroundingDINO.bertwarper import (
    generate_masks_with_special_tokens_and_transfer_map_nocate
)

In [2]:
def load_model(
    device="cpu",
    predictor=False,
    sam_checkpoint="./ckpts/sam_vit_l_0b3195.pth",
    sam_model_type="vit_l",
):
    """Load the encoders, tokenizer, image transform, and linear probes.

    Three model families are created and moved to ``device``:

    * Grounding DINO, obtained through ``load_model_hf`` from the
      ``ShilongLiu/GroundingDINO`` hub repository.
    * SAM, built from ``sam_model_registry[sam_model_type]`` with a local
      checkpoint.
    * BiomedCLIP, created through ``open_clip`` together with its tokenizer
      and image transforms.

    One ``LinearProbe`` is also constructed for each of: the Grounding DINO
    image features, the Grounding DINO text features, and the SAM image
    features. Each probe receives the listed input shapes, the value 512
    (presumably the projection width - confirm against ``LinearProbe``), and
    ``device``.

    Args:
        device: Device identifier handed to every model and probe.
        predictor: When true, the second element of the result is a
            ``SamPredictor`` wrapping the SAM model instead of the model
            itself.
        sam_checkpoint: Path of the SAM checkpoint file.
        sam_model_type: Key into ``sam_model_registry``; must match the
            architecture stored in ``sam_checkpoint``.

    Returns:
        An 8-tuple ``(groundingdino, sam_or_predictor, biomedclip, tokenizer,
        preprocess_train, groundingdino_img_linear, groundingdino_txt_linear,
        sam_linear)``.
    """
    # Grounding DINO: checkpoint and config are resolved by load_model_hf.
    ckpt_repo_id = "ShilongLiu/GroundingDINO"
    ckpt_filename = "groundingdino_swinb_cogcoor.pth"
    ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
    groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename, device=device)
    groundingdino.to(device)

    # SAM
    sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
    sam.to(device)

    # BiomedCLIP
    biomedclip_id = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
    # NOTE(review): the *training* transform is what gets returned and the
    # validation transform is dropped - confirm this is intended when the
    # transform is used for inference.
    biomedclip, preprocess_train, _preprocess_val = open_clip.create_model_and_transforms(biomedclip_id)
    tokenizer = open_clip.get_tokenizer(biomedclip_id)
    biomedclip.to(device)

    # Linear probe over the Grounding DINO image features (three inputs).
    groundingdino_input_dims = [
        [1, 256, 28, 28],
        [1, 512, 14, 14],
        [1, 1024, 7, 7],
    ]
    groundingdino_img_linear = LinearProbe(groundingdino_input_dims, 512, device)

    # Linear probe over the Grounding DINO text features.
    groundingdino_txt_dims = [
        [1, 195, 256]
    ]
    groundingdino_txt_linear = LinearProbe(groundingdino_txt_dims, 512, device)

    # Linear probe over the SAM image features.
    sam_input_dims = [
        [1, 256, 64, 64]
    ]
    sam_linear = LinearProbe(sam_input_dims, 512, device)

    # Only build the predictor wrapper when the caller actually asked for it.
    sam_or_predictor = SamPredictor(sam) if predictor else sam

    return (
        groundingdino,
        sam_or_predictor,
        biomedclip,
        tokenizer,
        preprocess_train,
        groundingdino_img_linear,
        groundingdino_txt_linear,
        sam_linear,
    )


def preprocess_sam(sam, image_path, device="cpu", image_size=(20, 20)):
    """Load an image and standardize it with SAM's per-channel pixel statistics.

    The file is opened, converted to RGB, resized to ``image_size``, moved to
    ``device``, and standardized with the mean / std constants below.

    Args:
        sam: Not used by this function; kept so existing call sites keep
            working.
        image_path: Path of the image file to read.
        device: Device the resulting tensor is placed on.
        image_size: ``(height, width)`` handed to ``Resize``. The default is
            the ``(20, 20)`` this function has always used.
            NOTE(review): SAM's image encoder is normally fed images at its
            native resolution (``sam.image_encoder.img_size``) - confirm
            whether 20x20 is intentional and pass the native size if not.

    Returns:
        A float tensor with a leading batch dimension of 1 and three
        channels, located on ``device``.
    """
    # Per-channel statistics expressed on the 0-255 intensity scale.
    pixel_mean = [123.675, 116.28, 103.53]
    pixel_std = [58.395, 57.12, 57.375]

    to_unit_tensor = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_size),
        torchvision.transforms.ToTensor(),
    ])

    # Force exactly three channels so the per-channel statistics line up for
    # grayscale, palette, and RGBA files alike; the context manager releases
    # the file handle once the pixel data has been converted.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    # ToTensor produces values in [0, 1] while the statistics above are on the
    # 0-255 scale, so bring the pixels back to 0-255 before standardizing.
    pixels = to_unit_tensor(rgb_image).to(device) * 255.0

    mean = torch.tensor(pixel_mean, device=device).view(-1, 1, 1)
    std = torch.tensor(pixel_std, device=device).view(-1, 1, 1)
    normalized = (pixels - mean) / std
    return normalized[None, :, :, :]


def preprocess_biomedclip(preprocess, tokenizer, image_path, text, device="cpu"):
    """Build the BiomedCLIP image and text inputs for one image/text pair.

    Args:
        preprocess: Callable applied to the opened PIL image; its result must
            support ``unsqueeze`` and ``to``.
        tokenizer: Callable invoked as ``tokenizer(text, context_length=256)``;
            its result must support ``to``.
        image_path: Path of the image file to open.
        text: Text handed to ``tokenizer``.
        device: Device both results are moved to.

    Returns:
        ``(image_batch, token_batch)`` - the transformed image with a new
        leading dimension, and the tokenizer output, both on ``device``.
    """
    pil_image = Image.open(image_path)
    image_batch = preprocess(pil_image).unsqueeze(0)
    image_batch = image_batch.to(device)

    token_batch = tokenizer(text, context_length=256)
    token_batch = token_batch.to(device)

    return image_batch, token_batch


def preprocess_groundingdino_img(image_path, device="cpu"):
    """Prepare a single image as Grounding DINO backbone input.

    ``load_image`` returns two values; only the second one is used here. It
    is resized to 224x224, wrapped with ``nested_tensor_from_tensor_list`` as
    a one-element list, and moved to ``device``.

    Args:
        image_path: Path handed to ``load_image``.
        device: Device the wrapped result is moved to.

    Returns:
        The object produced by ``nested_tensor_from_tensor_list`` after its
        ``to(device)`` call.
    """
    _source, image_tensor = load_image(image_path)

    resize_to_224 = torchvision.transforms.Resize((224, 224))
    resized = resize_to_224(image_tensor)

    batch = nested_tensor_from_tensor_list([resized])
    return batch.to(device)

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load every model, the tokenizer, the image transform, and the linear probes.
(
    groundingdino,
    sam,
    biomedclip,
    tokenizer,
    preprocess_train,
    groundingdino_img_linear,
    groundingdino_txt_linear,
    sam_linear,
) = load_model(device)

# Toy inputs shared by all three model families.
img_path = "./toy_data/chest_x_ray.jpeg"
text = "This is a image a 2 lungs."

# --- Preprocess: these are model *inputs*, not embeddings yet ---------------
sam_img = preprocess_sam(sam, img_path, device)
print("SAM input image shape:", sam_img.shape)

bmc_img, bmc_txt = preprocess_biomedclip(preprocess_train, tokenizer, img_path, text, device)
print("Biomed CLIP input image shape:", bmc_img.shape)
print("Biomed CLIP input token shape:", bmc_txt.shape)

groundingdino_img = preprocess_groundingdino_img(img_path, device)
print("Grounding Dino input image shape:", groundingdino_img.shape)

# --- Generate embeddings -----------------------------------------------------
# NOTE(review): nothing here runs under torch.no_grad(); if these tensors are
# only inspected (never back-propagated through), wrapping the encoder calls
# would save memory - confirm against later cells before changing.

# SAM image embedding: encoder output, then the linear probe.
sam_img_embedding = sam.image_encoder(sam_img)
print("SAM encoder output shape:", sam_img_embedding.shape)
sam_img_embedding = sam_linear(sam_img_embedding)
print("SAM projected image embedding shape:", sam_img_embedding.shape)

# Biomed CLIP image + text embeddings
bmc_img_embedding = biomedclip.visual(bmc_img)
bmc_txt_embedding = biomedclip.encode_text(bmc_txt)
print("Biomed CLIP image embedding shape:", bmc_img_embedding.shape)
print("Biomed CLIP text embedding shape:", bmc_txt_embedding.shape)

# Grounding Dino image embedding: one tensor per backbone output.
backbone_output, _ = groundingdino.backbone(groundingdino_img)
groundingdino_img_embedding = [emb.tensors.to(device) for emb in backbone_output]
print("Grounding Dino backbone output shape (first level):", groundingdino_img_embedding[0].shape)

# Grounding Dino text embedding: tokenize to a fixed length of 195 and move
# every tokenizer output to the target device.
tokenized = groundingdino.tokenizer(text, padding="max_length", max_length=195, return_tensors="pt")
for key, value in tokenized.items():
    tokenized[key] = value.to(device)

text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map_nocate(
    tokenized, groundingdino.specical_tokens, groundingdino.tokenizer
)

# Truncate masks, positions, and token fields together when the sequence is
# longer than the model's max_text_len.
max_text_len = groundingdino.max_text_len
if text_self_attention_masks.shape[1] > max_text_len:
    text_self_attention_masks = text_self_attention_masks[:, :max_text_len, :max_text_len]
    position_ids = position_ids[:, :max_text_len]
    tokenized["input_ids"] = tokenized["input_ids"][:, :max_text_len]
    tokenized["attention_mask"] = tokenized["attention_mask"][:, :max_text_len]
    tokenized["token_type_ids"] = tokenized["token_type_ids"][:, :max_text_len]

if groundingdino.sub_sentence_present:
    # Replace the tokenizer's attention mask with the generated self-attention
    # mask and pass the generated position ids explicitly.
    tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
    tokenized_for_encoder["attention_mask"] = text_self_attention_masks
    tokenized_for_encoder["position_ids"] = position_ids
else:
    tokenized_for_encoder = tokenized

bert_output = groundingdino.bert(**tokenized_for_encoder)
groundingdino_txt_embedding = groundingdino.feat_map(bert_output["last_hidden_state"]).to(device)
print("Grounding Dino text feature shape:", groundingdino_txt_embedding.shape)

# Project both Grounding Dino modalities through their linear probes.
groundingdino_img_embedding = groundingdino_img_linear(groundingdino_img_embedding)
groundingdino_txt_embedding = groundingdino_txt_linear(groundingdino_txt_embedding)
print("Grounding Dino projected image embedding shape:", groundingdino_img_embedding.shape)
print("Grounding Dino projected text embedding shape:", groundingdino_txt_embedding.shape)

final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded from /home/mam0364/.cache/huggingface/hub/models--ShilongLiu--GroundingDINO/snapshots/a94c9b567a2a374598f05c584e96798a170c56fb/groundingdino_swinb_cogcoor.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight'])
