In [3]:
!pip install datasets -q
!pip install wandb -q
!pip install scipy -q
!pip install matplotlib -q

Successfully installed contourpy-1.2.1 cycler-0.12.1 fonttools-4.51.0 kiwisolver-1.4.5 matplotlib-3.8.4 pyparsing-3.1.2


In [4]:
import os
import wandb
from tqdm import tqdm
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, DefaultDataCollator
from datasets import load_dataset
from scipy.special import softmax
import torch.nn.functional as F
import matplotlib.pyplot as plt
# SECURITY: never hard-code the Weights & Biases API key in the notebook.
# A key that has been committed in plain text should be treated as leaked and
# rotated. The key is taken from the environment when already exported;
# otherwise it is requested interactively without echoing it to the output.
if not os.environ.get('WANDB_API_KEY'):
    import getpass
    os.environ['WANDB_API_KEY'] = getpass.getpass('Weights & Biases API key: ')

In [7]:
# Checkpoint shared by the image processor and the classification model.
ckpt_name = "google/vit-base-patch16-224"

# Open the ImageNet-1k validation split as a stream and have examples
# returned in torch format.
dataset = load_dataset(
    "imagenet-1k",
    split="validation",
    streaming=True,
    trust_remote_code=True,
).with_format('torch')
print("DATASET DONE")

# Preprocessing pipeline matching the checkpoint.
image_processor = AutoImageProcessor.from_pretrained(ckpt_name)
print("IMAGE PROCESSOR DONE")

# Pretrained classifier that is evaluated in the cells below.
model = AutoModelForImageClassification.from_pretrained(ckpt_name)
print("MODEL DONE")

# Seeded shuffle of the stream, then keep the first 11,500 shuffled examples.
# NOTE(review): a streaming shuffle presumably has to fill its 11,500-example
# buffer before the first item is produced, which may explain a long wait at
# the start of iteration — confirm against the datasets documentation.
shuffled_stream = dataset.shuffle(seed=42, buffer_size=11_500)
newdataset = shuffled_stream.take(11500)


DATASET DONE




IMAGE PROCESSOR DONE
MODEL DONE


In [8]:
# Run the classifier over the streamed validation subset one image at a time,
# keeping the raw logits and labels so that accuracy and the calibration
# metrics / plots at the end of the cell can be computed from them.
correct = 0
total = 0

logits_list = []
labels_list = []
broken_ids = []  # stream positions of images that failed preprocessing

with torch.no_grad():

    for i, data in enumerate(tqdm(iter(newdataset))):
        try:
            inputs = image_processor(data['image'], return_tensors="pt")
            # Fixed: the previous check compared the torch.Size object itself
            # with the integer 4, which is never equal, so the squeeze below
            # could not run.
            # NOTE(review): the assumed intent is to drop a spurious singleton
            # axis from a 5-D batch so the model receives a 4-D tensor —
            # confirm the expected rank.
            if inputs['pixel_values'].dim() == 5:
                inputs['pixel_values'] = inputs['pixel_values'].squeeze(1)
        except Exception:
            # Best-effort on purpose: images that cannot be preprocessed are
            # recorded and skipped. Catching Exception instead of using a bare
            # except lets KeyboardInterrupt / SystemExit still stop the cell.
            broken_ids.append(i)
            continue
        outputs = model(**inputs)
        logits = outputs.logits
        # Add a leading axis so the label can be concatenated with the others.
        labels = data['label'].unsqueeze(dim=0)

        logits_list.append(logits)
        labels_list.append(labels)

        output_probs = F.softmax(logits, dim=1)
        probs, predicted = torch.max(output_probs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Stop once 10,000 images have been scored successfully.
        if len(labels_list) == 10000:
            break

    # Fail with a clear message instead of an opaque torch.cat error or a
    # division by zero when nothing could be evaluated.
    if not logits_list:
        raise RuntimeError(
            'No images were evaluated (%d failed preprocessing).' % len(broken_ids)
        )

    logits_np = torch.cat(logits_list).numpy()
    labels_np = torch.cat(labels_list).numpy()

    # Probabilities for the metric variant that takes probabilities directly.
    softmaxes = softmax(logits_np, axis=1)

print('Broken images, which are not included: %d' % (len(broken_ids)))
print('Accuracy on the 10,000 validation images: %d %%' % (100 * correct / total))
print(f'number of images used {total}')

# NOTE(review): ECELoss, SCELoss, ConfidenceHistogram and ReliabilityDiagram
# are not defined in the cells visible here — presumably they are defined in
# another cell or module; confirm they exist on a fresh kernel run.
ece_criterion = ECELoss()
sce_criterion = SCELoss()

print('ECE: %f' % (ece_criterion.loss(logits_np,labels_np, 15)))

print('ECE with probabilties %f' % (ece_criterion.loss(softmaxes,labels_np,15,False)))

print('SCE: %f' % (sce_criterion.loss(logits_np,labels_np, 15)))

conf_hist = ConfidenceHistogram()
plt_test = conf_hist.plot(logits_np,labels_np,title="Confidence Histogram")
plt_test.savefig('conf_histogram_test.png',bbox_inches='tight')
plt_test.show()

rel_diagram = ReliabilityDiagram()
plt_test_2 = rel_diagram.plot(logits_np,labels_np,title="Reliability Diagram")
plt_test_2.savefig('rel_diagram_test.png',bbox_inches='tight')
plt_test_2.show()

0it [27:09, ?it/s]


KeyboardInterrupt: 