In [48]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import AutoModel
from transformers import AutoModelForImageClassification, AutoImageProcessor
from datasets import load_dataset
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
import pickle
from sklearn.metrics import accuracy_score

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Short model key -> base checkpoint identifier passed to
# AutoImageProcessor.from_pretrained. The key is also used to build the
# fine-tuned checkpoint folder name and the predictions file name.
MODEL_NAMES = dict(
    dino_v2="facebook/dinov2-large",
    swin="microsoft/swin-tiny-patch4-window7-224",
    vit="google/vit-base-patch16-224",
)

In [37]:
# Load the image-folder dataset stored under ./data/training_data.
# transform_ds reads its "test" split.
dataset = load_dataset(
    path="imagefolder",
    data_dir="./data/training_data",
)

def _resolve_resize_and_crop(size_config):
    """Turn an image processor ``size`` mapping into (resize, crop) arguments.

    Args:
        size_config: Mapping containing either ``"height"`` and ``"width"``
            or ``"shortest_edge"``.

    Returns:
        A ``(size, crop_size)`` pair: ``size`` is passed to ``Resize`` and
        ``crop_size`` to ``CenterCrop``.

    Raises:
        ValueError: If neither ``"height"`` nor ``"shortest_edge"`` is present.
    """
    if "height" in size_config:
        target = (size_config["height"], size_config["width"])
        return target, target
    if "shortest_edge" in size_config:
        edge = size_config["shortest_edge"]
        return edge, (edge, edge)
    # Previously this case fell through and failed later with an
    # UnboundLocalError on `size`; fail early with an actionable message.
    raise ValueError(
        f"Unsupported image_processor.size keys {sorted(size_config)}; "
        "expected 'height'/'width' or 'shortest_edge'."
    )


def transform_ds(image_processor):
    """Return the test split with evaluation-time preprocessing attached.

    Each accessed batch gains a ``"pixel_values"`` entry holding the images
    converted to RGB, resized, center-cropped, converted with ``ToTensor`` and
    normalized with the processor's mean/std. Existing columns are kept.

    Args:
        image_processor: Object exposing ``size``, ``image_mean`` and
            ``image_std`` (e.g. the result of
            ``AutoImageProcessor.from_pretrained``).

    Returns:
        A view of ``dataset["test"]`` that applies the transform on access.
        The shared ``dataset["test"]`` object itself is left untouched.

    Raises:
        ValueError: If ``image_processor.size`` has an unsupported layout.
    """
    # Validate the size configuration first so bad input fails fast.
    size, crop_size = _resolve_resize_and_crop(image_processor.size)

    normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

    def preprocess_val(example_batch):
        """Apply val_transforms across a batch."""
        example_batch["pixel_values"] = [
            val_transforms(image.convert("RGB")) for image in example_batch["image"]
        ]
        return example_batch

    # with_transform returns a new dataset object instead of mutating the
    # shared split in place, so repeated calls with different processors do
    # not overwrite each other's transform.
    return dataset["test"].with_transform(preprocess_val)

In [57]:
# Directory holding one "<model_type>-finetuned-dermnet" checkpoint folder per model.
# NOTE(review): machine-specific absolute path; consider moving it to a config cell.
fpth = r'C:/Users/Shankar/Desktop/goat-ipython/goat-vault/goat-vault/01 - Notes/03 - Resources/Codes/CMI/AML/Project/Training/'

# Evaluate every fine-tuned model on the test split, save its predictions and
# print its accuracy.
for model_type, model_name in MODEL_NAMES.items():
    print(f"Starting evaluation for {model_type}...")

    # Predicted and actual labels collected for this model only.
    predicted_labels = []
    actual_labels = []

    # The processor comes from the base checkpoint; the weights come from the
    # local fine-tuned folder.
    image_processor = AutoImageProcessor.from_pretrained(model_name)

    # Place the whole model on `device` explicitly. The previous
    # device_map="auto" let the library choose placement and then .to(device)
    # tried to move the model again, which is redundant on a single device and
    # is rejected or warned about when the model is sharded/offloaded.
    model = AutoModelForImageClassification.from_pretrained(
        fpth + model_type + "-finetuned-dermnet"
    ).to(device)
    model.eval()

    # NOTE(review): transform_ds also attaches "pixel_values" to each example,
    # but the prediction below uses the image processor's own preprocessing of
    # the raw image, so that tensor is unused here.
    test_ds = transform_ds(image_processor)

    for example in test_ds:
        image = example["image"]
        encoding = image_processor(image.convert("RGB"), return_tensors="pt").to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model(**encoding)
            logits = outputs.logits

        predicted_class_idx = logits.argmax(-1).item()
        predicted_labels.append(predicted_class_idx)
        actual_labels.append(example["label"])

    # Persist [predicted_labels, actual_labels] as "<model_type>_predictions.pkl".
    with open(model_type + '_predictions.pkl', 'wb') as f:
        pickle.dump([predicted_labels, actual_labels], f)

    # Calculate accuracy
    accuracy = accuracy_score(actual_labels, predicted_labels)
    print(f"Test Accuracy for {model_type}: {accuracy:.4f}")

Starting evaluation for dino_v2...
Test Accuracy for dino_v2: 0.7850
Starting evaluation for swin...
Test Accuracy for swin: 0.6953
Starting evaluation for vit...


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Test Accuracy for vit: 0.7324


In [4]:
# Reload the [predicted_labels, actual_labels] pair that the evaluation loop
# saved for dino_v2 as "<model_type>_predictions.pkl".
# Pickle data is binary, so the file must be opened with 'rb'; text mode
# makes pickle.load fail on Python 3.
# NOTE: only unpickle files you produced yourself; pickle.load can execute
# arbitrary code from a crafted file.
with open('dino_v2_predictions.pkl', 'rb') as f:
    predicted_labels, actual_labels = pickle.load(f)