In [None]:
# Experiment of using vit-base as backbone and using logistic regression as classification head

import os
import torch
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTModel  # Use ViTModel instead of ViTForImageClassification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import numpy as np
import time  # For latency measurement
from thop import profile  # For FLOPs calculation 

# === Constants ===
MODEL_NAME = "google/vit-base-patch16-224"  # Pre-trained ViT model
# MODEL_NAME = "google/vit-base-patch16-224-in21k"
BATCH_SIZE = 32
IMAGE_SIZE = 224

# === Load Model and Processor ===
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME, do_rescale=False)
model = ViTModel.from_pretrained(MODEL_NAME).eval()  # Use ViTModel for feature extraction

# === Prepare Dataset ===
transform = Compose([
    Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize to 224x224
    ToTensor()                        # Convert to tensor
])

# Path to your dataset
dataset_root = "/Users/vladimirzvyozdkin/Studies_Hildesheim/SRP/ViT B16 replication"  # Specify the path to your data folder

# Load Oxford-IIIT Pets dataset
train_dataset = OxfordIIITPet(root=dataset_root, split="trainval", transform=transform, download=False)
test_dataset = OxfordIIITPet(root=dataset_root, split="test", transform=transform, download=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# === Feature Extraction ===
def extract_features(dataloader, model):
    features, labels = [], []
    with torch.no_grad():  # Disable gradients
        for images, targets in tqdm(dataloader, desc="Extracting features"):
            inputs = feature_extractor(images, return_tensors="pt")
            outputs = model(**inputs)
            embeddings = outputs.last_hidden_state[:, 0, :]  # [CLS] token
            features.append(embeddings.cpu().numpy())
            labels.append(targets.numpy())
    return np.vstack(features), np.concatenate(labels)

print("Extracting features for the training set...")
train_features, train_labels = extract_features(train_loader, model)

print("Extracting features for the test set...")
test_features, test_labels = extract_features(test_loader, model)

# === Train Logistic Regression ===
print("Training logistic regression...")
clf = LogisticRegression(max_iter=500, verbose=1)
clf.fit(train_features, train_labels)

# === Evaluate Accuracy ===
predictions = clf.predict(test_features)
accuracy = accuracy_score(test_labels, predictions)
print(f"Accuracy on the test set: {accuracy:.2%}")

# === Measure Latency ===
print("Measuring latency...")
start_time = time.time()
clf.predict(test_features[:1])  # Measure latency for a single prediction
end_time = time.time()
latency = (end_time - start_time) * 1000  # Convert to milliseconds
print(f"Latency (inference time for one sample): {latency:.2f} ms")

# === Measure FLOPs ===
print("Calculating FLOPs...")
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)  # Create a dummy input for FLOPs calculation
macs, params = profile(model, inputs=(dummy_input,))
flops = macs * 2  # Multiply MACs by 2 to get FLOPs
print(f"FLOPs (floating-point operations): {flops:,}")


Copy of the output

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']  
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.  
Extracting features for the training set...  
Extracting features: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [04:17<00:00,  2.24s/it]  
Extracting features for the test set...  
Extracting features: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [04:41<00:00,  2.44s/it]  
Training logistic regression...  
RUNNING THE L-BFGS-B CODE  

           * * *  

Machine precision = 2.220D-16  
 N =        28453     M =           10  
 This problem is unconstrained.  

At X0         0 variables are exactly at the bounds  

At iterate    0    f=  3.61092D+00    |proj g|=  1.18076D-01  

           * * *  

Tit   = total number of iterations  
Tnf   = total number of function evaluations  
Tnint = total number of segments explored during Cauchy searches  
Skip  = number of BFGS updates skipped  
Nact  = number of active bounds at final generalized Cauchy point  
Projg = norm of the final projected gradient  
F     = final function value  

           * * *  
 
 |N|Tit|Tnf|Tnint|Skip|Nact|Projg|F|
 |-|---|---|-----|----|----|-----|-|
|28453  |   39 |    41   |   1   |  0   |  0   |8.169D-05 |  1.387D-02 | 
  F =   1.3872503921853369E-002
  
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL              
Accuracy on the test set: 92.72%  
Measuring latency...  
Latency (inference time for one sample): 0.53 ms  
Calculating FLOPs...  
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.  
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.  
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.  
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.  
FLOPs (floating-point operations): 33,726,904,320.0  