In [1]:
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

# One Hugging Face Hub id selects both the BiomedCLIP model (returned together
# with its image preprocessing transform) and the matching text tokenizer.
BIOMEDCLIP_HUB_ID = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'

model, preprocess = create_model_from_pretrained(BIOMEDCLIP_HUB_ID)
tokenizer = get_tokenizer(BIOMEDCLIP_HUB_ID)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
from urllib.request import urlopen
from PIL import Image

# Zero-shot classification: every test image is scored against every candidate
# label, each label rendered as a natural-language prompt via `template`.
template = "this is a photo of "
labels = [
    "adenocarcinoma histopathology",
    "brain MRI",
    "covid line chart",
    "squamous cell carcinoma histopathology",
    "immunohistochemistry histopathology",
    "bone X-ray",
    "chest X-ray",
    "pie chart",
    "hematoxylin and eosin histopathology",
]

dataset_url = "https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/"
test_imgs = [
    "squamous_cell_carcinoma_histopathology.jpeg",
    "H_and_E_histopathology.jpg",
    "bone_X-ray.jpg",
    "adenocarcinoma_histopathology.jpg",
    "covid_line_chart.png",
    "IHC_histopathology.jpg",
    "chest_X-ray.jpg",
    "brain_MRI.jpg",
    "pie_chart.png",
]

# How many ranked labels to print per image; -1 means "all labels".
top_k = -1

# Prompt token length passed to the tokenizer; 256 matches the
# "PubMedBERT_256" text tower named in the model id.
context_length = 256

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()


def load_image_tensor(url):
    """Download one image and return it passed through the model's `preprocess`.

    Both the HTTP response and the PIL image are closed via `with` once
    preprocessing has consumed the pixel data.
    """
    with urlopen(url) as response, Image.open(response) as img:
        return preprocess(img)


images = torch.stack(
    [load_image_tensor(dataset_url + img_name) for img_name in test_imgs]
).to(device)
texts = tokenizer(
    [template + label for label in labels], context_length=context_length
).to(device)

with torch.no_grad():
    # One forward pass returns both embedding sets and the logit scale; calling
    # model.encode_image separately would rerun the image tower for nothing.
    image_features, text_features, logit_scale = model(images, texts)
    print(f"Shape of image features: {image_features.shape}")

    # NOTE: despite the (kept) name, `logits` holds softmax probabilities over
    # the label axis, one row per image. No .detach() needed under no_grad().
    logits = ((logit_scale * image_features) @ text_features.t()).softmax(dim=-1)
    sorted_indices = torch.argsort(logits, dim=-1, descending=True)

    logits = logits.cpu().numpy()
    sorted_indices = sorted_indices.cpu().numpy()

# Resolve the "all labels" sentinel once, outside the loop, and never ask for
# more ranks than there are labels (that would raise IndexError below).
top_k = len(labels) if top_k == -1 else min(top_k, len(labels))

for i, img in enumerate(test_imgs):
    print(img.split("/")[-1] + ":")
    for j in range(top_k):
        jth_index = sorted_indices[i][j]
        print(f"{labels[jth_index]}: {logits[i][jth_index]}")
    print("\n")

torch.Size([9, 512])
Shape of image features: torch.Size([9, 512])
squamous_cell_carcinoma_histopathology.jpeg:
squamous cell carcinoma histopathology: 0.997437596321106
adenocarcinoma histopathology: 0.001306830788962543
hematoxylin and eosin histopathology: 0.0012188811087980866
immunohistochemistry histopathology: 3.669962825370021e-05
chest X-ray: 1.1548410665251918e-11
brain MRI: 5.123110269344977e-12
pie chart: 2.510576815883958e-12
covid line chart: 8.306049533064741e-13
bone X-ray: 2.356631498290091e-14


H_and_E_histopathology.jpg:
hematoxylin and eosin histopathology: 0.9871463775634766
immunohistochemistry histopathology: 0.012638912536203861
adenocarcinoma histopathology: 0.00014946787268854678
squamous cell carcinoma histopathology: 5.230658280197531e-05
brain MRI: 1.0339951586502139e-05
chest X-ray: 1.7928430224856129e-06
bone X-ray: 6.8211335246815e-07
pie chart: 2.454528669204592e-07
covid line chart: 4.0791800492989694e-11


bone_X-ray.jpg:
bone X-ray: 0.99947971105575

In [6]:
!pip uninstall transformers -y

Found existing installation: transformers 4.24.0
Uninstalling transformers-4.24.0:
  Successfully uninstalled transformers-4.24.0


In [3]:
!pip uninstall transformers -y

Found existing installation: transformers 4.30.2
Uninstalling transformers-4.30.2:
  Successfully uninstalled transformers-4.30.2


In [4]:
!pip install transformers=="4.29.2"

Collecting transformers==4.29.2
  Using cached transformers-4.29.2-py3-none-any.whl.metadata (112 kB)
Using cached transformers-4.29.2-py3-none-any.whl (7.1 MB)
Installing collected packages: transformers
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
medclip 0.0.3 requires transformers<=4.24.0,>=4.23.1, but you have transformers 4.29.2 which is incompatible.
hest 1.1.1 requires transformers>=4.40.2, but you have transformers 4.29.2 which is incompatible.[0m[31m
[0mSuccessfully installed transformers-4.29.2
