In [1]:
import numpy as np
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from transformers_interpret import ImageClassificationExplainer
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision import transforms

In [2]:
model_name = "e1010101/vit-384-tongue-image-segmented-augmented"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)



In [3]:
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor()
])

In [4]:
image = Image.open("../Samples/segmented/segmented_sample_2.jpg")
image_transformed = transform(image).unsqueeze(0)

In [5]:
idx_to_labels = {
    0: "Crack",
    1: "Red Dots",
    2: "Toothmark"
}

In [6]:
output = model(image_transformed)
output.logits

tensor([[-2.8335, -0.7818,  3.4794]], grad_fn=<AddmmBackward0>)

In [7]:
output_softmax = F.softmax(output.logits, dim=1)
prediction_score, pred_label_idx = torch.topk(output_softmax, 1)

pred_label_idx.squeeze_()
pred_label_idx.item()

2

In [8]:
predicted_label = idx_to_labels[pred_label_idx.item()]
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')

Predicted: Toothmark ( 0.9843311905860901 )


In [11]:
image_classification_explainer = ImageClassificationExplainer(model=model, feature_extractor=feature_extractor, attribution_type="IG")

image_attributions = image_classification_explainer(image_transformed)

print(image_attributions.shape)

(1, 3, 384, 384)


In [12]:
image_classification_explainer.visualize(
    method="heatmap",
    side_by_side=True,
)

AttributeError: 'Tensor' object has no attribute 'astype'