In [1]:
import os
import cv2
import numpy as np
import torch
from transformers import ViTImageProcessor, ViTForImageClassification, Trainer
from datasets import Dataset
from sklearn.metrics import accuracy_score

  from .autonotebook import tqdm as notebook_tqdm





In [2]:
# Load the fine-tuned model
from transformers import ViTForImageClassification

# Local directory holding the fine-tuned checkpoint (model config + weights),
# read back with `from_pretrained`.
model_directory = "vit-base-beans"
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=model_directory,
)



In [5]:
def get_images(path, positive_class="struck"):
    """Collect image file paths and binary labels from a class-per-folder layout.

    ``path`` is expected to contain one sub-directory per class. Every regular,
    non-hidden file inside a class directory is treated as an image.

    Args:
        path: Root directory whose sub-directories are class folders.
        positive_class: Name of the class folder labelled ``1``. Every other
            class folder is labelled ``0``. Defaults to ``"struck"``.

    Returns:
        A ``(image_path, struck)`` pair of equal-length lists holding the file
        paths and their 0/1 labels, in sorted (deterministic) order.
    """
    image_path = []
    struck = []
    # os.listdir order is arbitrary; sort so runs are reproducible.
    for class_name in sorted(os.listdir(path)):
        class_dir = os.path.join(path, class_name)
        # Skip hidden entries (e.g. Jupyter's top-level ``.ipynb_checkpoints``,
        # which would otherwise be treated as a class) and stray files, on
        # which os.listdir would raise NotADirectoryError.
        if class_name.startswith(".") or not os.path.isdir(class_dir):
            continue
        label = 1 if class_name == positive_class else 0
        for file_name in sorted(os.listdir(class_dir)):
            file_path = os.path.join(class_dir, file_name)
            # Skip hidden files/dirs (``.ipynb_checkpoints``, ``.DS_Store``) and
            # nested directories; cv2.imread cannot decode them.
            if file_name.startswith(".") or not os.path.isfile(file_path):
                continue
            image_path.append(file_path)
            struck.append(label)
    return image_path, struck

In [6]:
# Root of the validation split; each sub-directory is one class folder.
validation_path = os.path.join("4765063", "validation", "validation")

In [7]:
# Collect validation image paths and their 0/1 "struck" labels.
validation_image_path, validation_struck = get_images(validation_path)
# Fail fast: an empty or mistyped path would otherwise only surface later as an
# opaque torch.stack error on an empty list.
assert validation_image_path, f"No validation images found under {validation_path!r}"

In [8]:
def image_preprocessing(image_path, threshold=200, size=(224, 224)):
    """Load images and convert them to binarized 3-channel tensors.

    Each image is read with OpenCV, converted to grayscale, thresholded,
    resized, replicated into three identical channels, and scaled by 1/255.

    NOTE(review): no mean/std normalization is applied here. This presumably
    mirrors how the model was fine-tuned; confirm against the training code.

    Args:
        image_path: Iterable of image file paths.
        threshold: Grayscale cutoff passed to ``cv2.threshold``. Pixels above it
            become 255 and the rest become 0. Defaults to 200.
        size: ``(width, height)`` passed to ``cv2.resize``.
            Defaults to ``(224, 224)``.

    Returns:
        List of float32 tensors, channels last (``height x width x 3``), one
        per input path, in input order.

    Raises:
        ValueError: If a file cannot be read as an image.
    """
    images = []
    for path in image_path:
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        # Without this check the failure surfaces as an opaque cv2.error
        # inside cvtColor.
        if img is None:
            raise ValueError(f"Could not read image file: {path!r}")
        gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        _, binary_image = cv2.threshold(gray_image, threshold, 255, cv2.THRESH_BINARY)
        # Resizing after thresholding (default bilinear interpolation) can
        # leave intermediate values along edges, not strictly 0/255.
        binary_image = cv2.resize(binary_image, size)
        binary_image = cv2.merge([binary_image, binary_image, binary_image])
        # float32 halves memory versus the float64 that `uint8 / 255` yields
        # and matches PyTorch's default parameter dtype.
        binary_image = binary_image.astype(np.float32) / 255
        images.append(torch.from_numpy(binary_image))
    return images

In [9]:
# Prepare the validation dataset: preprocess every image and stack the
# channels-last (H, W, 3) tensors into one (N, H, W, 3) batch. Then reorder
# to channels-first (N, 3, H, W), the layout the model's pixel_values use.
validation_images = torch.stack(
    image_preprocessing(validation_image_path)
).permute(0, 3, 1, 2)

In [10]:
# Column-name -> values mapping for building the Hugging Face Dataset.
validation = dict(
    image=validation_images,
    label=validation_struck,
)

In [11]:
# Wrap the column mapping in a datasets.Dataset so the Trainer can batch it.
validation_dataset = Dataset.from_dict(mapping=validation)

In [12]:
# Sanity check: number of columns per example (expected 2: image and label).
len(validation_dataset.column_names)

2

In [16]:
def collate_fn(batch):
    """Collate dataset rows into keyword arguments for ViTForImageClassification.

    Args:
        batch: Sequence of rows. Each row is a mapping with an array-like,
            channels-first ``"image"`` and an integer ``"label"``.

    Returns:
        Dict with ``"pixel_values"`` (float32 tensor, rows stacked along a new
        leading batch dimension) and ``"labels"`` (int64 tensor). These keys are
        the argument names the model's ``forward`` accepts.
    """
    # Rows presumably come back from the Dataset as nested Python lists
    # (default format), which NumPy would turn into float64. Request float32
    # explicitly so the batch matches the model's default float32 weights
    # instead of relying on an implicit cast inside the model.
    pixel_values = torch.stack(
        [torch.from_numpy(np.asarray(item["image"], dtype=np.float32)) for item in batch]
    )
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}

In [17]:
# Import TrainingArguments, used to configure the evaluation-only Trainer
from transformers import TrainingArguments

In [18]:
# Evaluation-only configuration. remove_unused_columns=False keeps the
# "image"/"label" columns even though they don't match the model's forward()
# argument names; collate_fn maps them to pixel_values/labels itself.
training_args = TrainingArguments(
    push_to_hub=False,
    remove_unused_columns=False,
    per_device_eval_batch_size=16,
    output_dir="./vit-base-eval",
)

In [19]:
# Trainer used only for batched inference; collate_fn turns dataset rows into
# model inputs.
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=validation_dataset,
    data_collator=collate_fn,
)

In [20]:
# Run inference over the validation set. The returned PredictionOutput holds
# the logits in `.predictions` and the collated labels in `.label_ids`.
predictions = trainer.predict(test_dataset=validation_dataset)


In [21]:
# Predicted class = index of the largest logit along the class (last) axis.
predicted_labels = np.argmax(predictions.predictions, axis=-1)


In [22]:
# Ground-truth labels, in the same order get_images produced them, which is
# also the row order of validation_dataset.
true_labels = np.asarray(validation_struck)


In [23]:
# Fraction of validation images whose predicted class matches the label.
accuracy = accuracy_score(y_true=true_labels, y_pred=predicted_labels)



In [24]:
# Report accuracy (a fraction in [0, 1]) rounded to 4 decimal places.
print(f"Validation Accuracy: {accuracy:.4f}")

Validation Accuracy: 0.9960
