In [1]:
import open_clip

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')

open_clip_pytorch_model.bin:   0%|          | 0.00/599M [00:00<?, ?B/s]

open_clip_config.json:   0%|          | 0.00/469 [00:00<?, ?B/s]

In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

open_clip_pytorch_model.bin:   0%|          | 0.00/784M [00:00<?, ?B/s]

open_clip_config.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/225k [00:00<?, ?B/s]

## Single image BioCLIP prediction

In [5]:
import torch
from urllib.request import urlopen
from PIL import Image

template = 'this is a photo of '
labels = [
    'adenocarcinoma histopathology',
    'brain MRI',
    'covid line chart',
    'squamous cell carcinoma histopathology',
    'immunohistochemistry histopathology',
    'bone X-ray',
    'chest X-ray',
    'pie chart',
    'hematoxylin and eosin histopathology'
]

dataset_url = 'https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/'
test_imgs = ['brain_MRI.jpg']

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()

context_length = 256

images = torch.stack([preprocess(Image.open(urlopen(dataset_url + img))) for img in test_imgs]).to(device)
texts = tokenizer([template + l for l in labels], context_length=context_length).to(device)
with torch.no_grad():
    image_features, text_features, logit_scale = model(images, texts)

    logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
    sorted_indices = torch.argsort(logits, dim=-1, descending=True)

    logits = logits.cpu().numpy()
    sorted_indices = sorted_indices.cpu().numpy()

top_k = -1

for i, img in enumerate(test_imgs):
    pred = labels[sorted_indices[i][0]]

    top_k = len(labels) if top_k == -1 else top_k
    print(img.split('/')[-1] + ':')
    for j in range(top_k):
        jth_index = sorted_indices[i][j]
        print(f'{labels[jth_index]}: {logits[i][jth_index]}')
    print('\n')

brain_MRI.jpg:
brain MRI: 0.9999922513961792
hematoxylin and eosin histopathology: 5.9478638831933495e-06
immunohistochemistry histopathology: 1.6712749584257836e-06
pie chart: 1.055264178262405e-07
bone X-ray: 3.7441971301177546e-08
chest X-ray: 4.858768054560869e-09
adenocarcinoma histopathology: 1.9369059689466894e-09
squamous cell carcinoma histopathology: 2.331514703524107e-10
covid line chart: 3.6202614257102583e-12




## Apply GEM

In [7]:
pip install gem_torch

[0mCollecting gem_torch
  Obtaining dependency information for gem_torch from https://files.pythonhosted.org/packages/4c/39/f362a75f13104011ce460fa553b79395edd6fdadedfd7c721454ba71a789/gem_torch-1.0.1-py3-none-any.whl.metadata
  Downloading gem_torch-1.0.1-py3-none-any.whl.metadata (11 kB)
Collecting einops (from gem_torch)
  Obtaining dependency information for einops from https://files.pythonhosted.org/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl.metadata
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting opencv-python (from gem_torch)
  Obtaining dependency information for opencv-python from https://files.pythonhosted.org/packages/d9/64/7fdfb9386511cd6805451e012c537073a79a958a58795c4e602e538c388c/opencv_python-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Using cached opencv_python-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloa

In [8]:
import gem
gem.available_models()

'ViT-B/32            : openai\nViT-B/32            : laion400m_e31\nViT-B/32            : laion400m_e32\nViT-B/32            : laion2b_e16\nViT-B/32            : laion2b_s34b_b79k\nViT-B/32-quickgelu  : metaclip_400m\nViT-B/32-quickgelu  : metaclip_fullcc\nViT-B/16            : openai\nViT-B/16            : laion400m_e31\nViT-B/16            : laion400m_e32\nViT-B/16            : laion2b_s34b_b88k\nViT-B/16-quickgelu  : metaclip_400m\nViT-B/16-quickgelu  : metaclip_fullcc\nViT-B/16-plus-240   : laion400m_e31\nViT-B/16-plus-240   : laion400m_e32\nViT-L/14            : openai\nViT-L/14            : laion400m_e31\nViT-L/14            : laion400m_e32\nViT-L/14            : laion2b_s32b_b82k\nViT-L/14-quickgelu  : metaclip_400m\nViT-L/14-quickgelu  : metaclip_fullcc\nViT-L/14-336        : openai\n'

In [45]:
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

In [48]:
model

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768

In [49]:
model.visual.

TimmModel(
  (trunk): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
  

In [51]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip.transformer import VisionTransformer

from gem.gem_utils import SelfSelfAttention, GEMResidualBlock, modified_vit_forward


class GEMWrapper(nn.Module):
    def __init__(self, model, tokenizer, depth=7, ss_attn_iter=1, ss_attn_temp=None):
        super(GEMWrapper, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.depth = depth
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        self.patch_size = 16#self.model.visual.patch_size[0]
        self.apply_gem()

    def apply_gem(self):
        for i in range(1, self.depth):
            # Extract info from the original ViT
            num_heads = self.model.visual.transformer.resblocks[-i].attn.num_heads
            dim = int(self.model.visual.transformer.resblocks[-i].attn.head_dim * num_heads)
            qkv_bias = True
            # Init the self-self attention layer
            ss_attn = SelfSelfAttention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                        ss_attn_iter=self.ss_attn_iter, ss_attn_temp=self.ss_attn_temp)
            # Copy necessary weights
            ss_attn.qkv.weight.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_weight.clone()
            ss_attn.qkv.bias.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_bias.clone()
            ss_attn.proj.weight.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.weight.clone()
            ss_attn.proj.bias.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.bias.clone()
            # Swap the original Attention with our SelfSelfAttention
            self.model.visual.transformer.resblocks[-i].attn = ss_attn
            # Wrap Residual block to handle SelfSelfAttention outputs
            self.model.visual.transformer.resblocks[-i] = GEMResidualBlock(self.model.visual.transformer.resblocks[-i])
        # Modify ViT's forward function
        self.model.visual.forward = modified_vit_forward.__get__(self.model.visual, VisionTransformer)
        return

    def encode_text(self, text: list):
        prompts = [f'a photo of a {cls}.' for cls in text]
        tokenized_prompts = self.tokenizer(prompts).to(self.model.visual.proj.device)
        text_embedding = self.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)

    def min_max(self, logits):
        B, num_prompt = logits.shape[:2]
        logits_min = logits.reshape(B, num_prompt, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
        logits_max = logits.reshape(B, num_prompt, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
        logits = (logits - logits_min) / (logits_max - logits_min)
        return logits

    def forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool =False):
        """
        :param image: torch.Tensor [1, 3, H, W]
        :param text: list[]
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        """
        # Image
        W, H = image.shape[-2:]
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [1, N, dim]

        # Text
        text_embeddings = self.encode_text(text)  # [1, num_prompt, dim]

        # Image-Text matching
        img_txt_matching = image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [1, N, num_prompt]
        img_txt_matching = rearrange(img_txt_matching, 'b (w h) c -> b c w h',
                                     w=W//self.patch_size, h=H//self.patch_size)  # [1, num_prompt, w, h]

        # Interpolate
        img_txt_matching = F.interpolate(img_txt_matching, size=(W, H), mode='bilinear')  # [1, num_prompt, W, H]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)
        return img_txt_matching

    def batched_forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool =False):
        """
        :param image: torch.Tensor [B, 3, H, W]
        :param text: list[list[]]
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        """
        L = len(text)
        cumm_idx = np.cumsum([len(t) for t in text]).tolist()
        B, _, W, H = image.shape
        assert B == L, f'Number of prompts L: {L} should be the same as number of images B: {B}.'

        # Image
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [B, N, dim]

        # Text
        flatten_text = [t for sub_text in text for t in sub_text]
        text_embeddings = self.encode_text(flatten_text)  # [B, num_prompt, dim]

        # Image-Text matching
        img_txt_matching = 100 * image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [B, N, num_prompt]
        img_txt_matching = rearrange(img_txt_matching, 'b (w h) c -> b c w h',
                                     w=W // self.patch_size, h=H // self.patch_size)  # [B, num_prompt, w, h]

        # Interpolate
        img_txt_matching = F.interpolate(img_txt_matching, size=(W, H), mode='bilinear')  # [B,num_prompt, W, H]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)  # [B,num_prompt, W, H]

        # unflatten
        img_txt_matching = torch.tensor_split(img_txt_matching, cumm_idx[:-1], dim=1)
        img_txt_matching = [itm[i] for i, itm in enumerate(img_txt_matching)]
        return img_txt_matching

In [52]:
gem_model = GEMWrapper(model=model, tokenizer=tokenizer)

AttributeError: 'TimmModel' object has no attribute 'transformer'

## LeGrad

In [55]:
pip install legrad_torch

[0mCollecting legrad_torch
  Obtaining dependency information for legrad_torch from https://files.pythonhosted.org/packages/9f/a5/117d7d280926a76e48ddb36b40fb1e5b1b15dd1e0e2d25cac88180c57c07/legrad_torch-1.0-py3-none-any.whl.metadata
  Downloading legrad_torch-1.0-py3-none-any.whl.metadata (5.9 kB)
Downloading legrad_torch-1.0-py3-none-any.whl (13 kB)
[0m[33mDEPRECATION: meerkat-ml 0.2.5 has a non-standard dependency specifier multiprocess>=0.70.11Cython>=0.29.21. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of meerkat-ml or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mInstalling collected packages: legrad_torch
Successfully installed legrad_torch-1.0
Note: you may need to restart the kernel to use updated packages.


In [57]:
import legrad

In [58]:
import requests
from PIL import Image
import open_clip
import torch

from legrad import LeWrapper, LePreprocess
from legrad.utils import visualize

In [59]:
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

In [61]:
model_name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'

In [65]:
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

In [66]:
model.eval()
model = LeWrapper(model)

Activating necessary hooks and gradients ....


AttributeError: 'VisionTransformer' object has no attribute 'attn_pool'