In [1]:
import json
import numpy as np
import torch
from torchvision import datasets
from transformers import AutoProcessor, ResNetForImageClassification, pipeline

  from .autonotebook import tqdm as notebook_tqdm
2024-04-12 17:05:11.855582: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-12 17:05:11.882139: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-12 17:05:11.882183: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-12 17:05:11.882199: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-12 17:05:1

In [2]:
model_id = "microsoft/resnet-50"
# Both the preprocessor and the classifier are loaded from the same checkpoint
# so that image preprocessing matches what the model expects.
processor = AutoProcessor.from_pretrained(model_id)
# Image classifier whose `config.id2label` is used below to name predictions.
model = ResNetForImageClassification.from_pretrained(model_id)

def single_image_classification(image):
    """Classify a single image with the ResNet model and return its label text.

    Args:
        image: One image accepted by `processor` (here: an item from the
            ImageFolder dataset).

    Returns:
        The `model.config.id2label` string for the highest-scoring class.
    """
    pixel_inputs = processor(image, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**pixel_inputs)

    best_class_id = outputs.logits.argmax(-1).item()
    return model.config.id2label[best_class_id]




In [3]:
# Image dataset with one sub-folder per class; `dataset.classes` lists the
# folder names in the order used for the integer labels returned by `dataset[i]`.
dataset = datasets.ImageFolder(
    root='mini-ImageNet',
)

# Each value in the class-index file is indexable as [folder_name, class_name].
# Read explicitly as UTF-8 rather than relying on the platform's locale default.
with open('mini-ImageNet/imagenet_class_index.json', encoding='utf-8') as f:
    class_index = json.load(f)

# Only the values are needed; the JSON keys are not used.
folder_to_class = {entry[0]: entry[1] for entry in class_index.values()}
# Human-readable class names, position-aligned with the dataset's integer labels
# (later indexed as `all_classes[label]`).
all_classes = [folder_to_class[folder] for folder in dataset.classes]
all_classes

['house_finch',
 'robin',
 'triceratops',
 'green_mamba',
 'harvestman',
 'toucan',
 'goose',
 'jellyfish',
 'nematode',
 'king_crab',
 'dugong',
 'Walker_hound',
 'Ibizan_hound',
 'Saluki',
 'golden_retriever',
 'Gordon_setter',
 'komondor',
 'boxer',
 'Tibetan_mastiff',
 'French_bulldog',
 'malamute',
 'dalmatian',
 'Newfoundland',
 'miniature_poodle',
 'white_wolf',
 'African_hunting_dog',
 'Arctic_fox',
 'lion',
 'meerkat',
 'ladybug',
 'rhinoceros_beetle',
 'ant',
 'black-footed_ferret',
 'three-toed_sloth',
 'rock_beauty',
 'aircraft_carrier',
 'ashcan',
 'barrel',
 'beer_bottle',
 'bookshop',
 'cannon',
 'carousel',
 'carton',
 'catamaran',
 'chime',
 'clog',
 'cocktail_shaker',
 'combination_lock',
 'crate',
 'cuirass',
 'dishrag',
 'dome',
 'electric_guitar',
 'file',
 'fire_screen',
 'frying_pan',
 'garbage_truck',
 'hair_slide',
 'holster',
 'horizontal_bar',
 'hourglass',
 'iPod',
 'lipstick',
 'miniskirt',
 'missile',
 'mixing_bowl',
 'oboe',
 'organ',
 'parallel_bars',
 '

In [4]:
# Zero-shot text classifier (BART fine-tuned on MNLI), used as a fallback to map
# free-form ResNet label text onto the dataset's class names.
bart_pipe = pipeline(task="zero-shot-classification", model="facebook/bart-large-mnli")

def classify_string(text):
    """Pick the class from `all_classes` that best matches `text` via zero-shot NLI.

    Args:
        text: Free-form label text (e.g. a ResNet `id2label` string).

    Returns:
        The first entry of the pipeline's `labels` list, i.e. one of `all_classes`.
    """
    ranking = bart_pipe(text, candidate_labels=all_classes)
    # Taking index 0 assumes `labels` is ordered best-first, as the transformers
    # zero-shot pipeline documents.
    return ranking['labels'][0]

In [6]:
# Resume the labelling run from the checkpoint files written by this cell.
pred_labels = np.load('results/pred_labels.npy').tolist()
true_labels = np.load('results/true_labels.npy').tolist()
checkpoint_interval = 500
# Upper bound of the run (the original run stopped at index 10000).
max_samples = 10000

# Both lists are indexed by dataset position, so they must have equal length.
assert len(true_labels) == len(pred_labels), "checkpoint files are out of sync"


def _normalize_label(text):
    """Return a case- and space-insensitive key for comparing class names."""
    return text.lower().replace(" ", "_")


# Normalized key -> class name spelled exactly as in `all_classes`. Predictions
# must use that spelling; true labels do too (e.g. 'Saluki', 'iPod'), so a
# lower-cased prediction would never compare equal to its true label.
canonical_label = {_normalize_label(name): name for name in all_classes}

# Checkpoints may contain lower-cased direct matches (an earlier version of this
# loop lower-cased them); restore the canonical spelling so old and new entries
# are scored consistently. Entries without a match are kept unchanged.
pred_labels = [canonical_label.get(_normalize_label(p), p) for p in pred_labels]


def _predict_label(image):
    """Predict a class name from `all_classes` for a single image."""
    model_text_output = single_image_classification(image)
    # id2label strings may hold several comma-separated names; only the first
    # one is tried as a direct match against the dataset's classes.
    primary_name = model_text_output.split(",")[0]
    matched = canonical_label.get(_normalize_label(primary_name))
    if matched is not None:
        return matched
    # No direct match: let the zero-shot model choose the closest class.
    return classify_string(model_text_output)


# Start right after the last saved sample, so re-running this cell never
# appends duplicate entries.
start_index = len(true_labels)
end_index = min(max_samples, len(dataset))

with torch.no_grad():
    for i in range(start_index, end_index):
        image, label = dataset[i]

        true_label = all_classes[label]
        pred_label = _predict_label(image)

        true_labels.append(true_label)
        pred_labels.append(pred_label)

        print(f'{i} | True: {true_label} | Predicted: {pred_label}')

        if i % checkpoint_interval == 0:
            # Periodic checkpoint so a crash loses at most one interval of work.
            np.save('results/true_labels.npy', true_labels)
            np.save('results/pred_labels.npy', pred_labels)

np.save('results/true_labels.npy', true_labels)
np.save('results/pred_labels.npy', pred_labels)

5501 | True: frying_pan | Predicted: frying_pan
5502 | True: frying_pan | Predicted: frying_pan
5503 | True: frying_pan | Predicted: frying_pan
5504 | True: frying_pan | Predicted: frying_pan
5505 | True: frying_pan | Predicted: frying_pan
5506 | True: frying_pan | Predicted: frying_pan
5507 | True: frying_pan | Predicted: frying_pan
5508 | True: frying_pan | Predicted: upright
5509 | True: frying_pan | Predicted: frying_pan
5510 | True: frying_pan | Predicted: frying_pan
5511 | True: frying_pan | Predicted: frying_pan
5512 | True: frying_pan | Predicted: frying_pan
5513 | True: frying_pan | Predicted: frying_pan
5514 | True: frying_pan | Predicted: frying_pan
5515 | True: frying_pan | Predicted: frying_pan
5516 | True: frying_pan | Predicted: frying_pan
5517 | True: frying_pan | Predicted: wok
5518 | True: frying_pan | Predicted: frying_pan
5519 | True: frying_pan | Predicted: frying_pan
5520 | True: frying_pan | Predicted: frying_pan
5521 | True: frying_pan | Predicted: frying_pan
55

In [None]:
# Persist the current label lists (same files the labelling loop checkpoints to).
np.save(file='results/true_labels.npy', arr=true_labels)
np.save(file='results/pred_labels.npy', arr=pred_labels)