In [8]:
# Experiment of using vit-small as backbone and using logistic regression as classification head


import torch
import numpy as np
import time
from tqdm import tqdm
from thop import profile
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from torchvision.datasets import Flowers102
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torch.utils.data import DataLoader, random_split, ConcatDataset
from transformers import ViTModel

# Constants
MODEL_NAME = "facebook/dino-vits16"  # ViT-Small from DINO
BATCH_SIZE = 32
IMAGE_SIZE = 224
RANDOM_SEED = 42

# Load Model
print(f"Loading model: {MODEL_NAME}")
model = ViTModel.from_pretrained(MODEL_NAME).eval()

# Standard normalization values for ImageNet pre-trained models
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Prepare Dataset
transform = Compose([
    Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize to 224x224
    ToTensor(),
    Normalize(mean=mean, std=std)
])

dataset_root = "./data" 

print("Loading Oxford Flowers 102 dataset...")
# Try to load all splits
train_dataset = Flowers102(root=dataset_root, split="train", transform=transform, download=True)
val_dataset = Flowers102(root=dataset_root, split="val", transform=transform, download=True)
test_dataset = Flowers102(root=dataset_root, split="test", transform=transform, download=True)
    
# Combine all splits to create a complete dataset
complete_dataset = ConcatDataset([train_dataset, val_dataset, test_dataset])
    
# Calculate sizes for 80-20 split
total_size = len(complete_dataset)
train_size = int(0.8 * total_size)
test_size = total_size - train_size
    
# Create new random splits (80% train, 20% test)
train_dataset, test_dataset = random_split(
    complete_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(RANDOM_SEED)
)
   
print(f"Training samples: {len(train_dataset)} (80%)")
print(f"Test samples: {len(test_dataset)} (20%)")
print(f"Total samples: {total_size}")
    
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Loading model: facebook/dino-vits16


Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vits16 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading Oxford Flowers 102 dataset...
Training samples: 6551 (80%)
Test samples: 1638 (20%)
Total samples: 8189


In [9]:
# Feature Extraction
def extract_features(dataloader, model):
    features = []
    labels = []
    model.eval()
    with torch.no_grad():  # Disable gradients
        for batch in tqdm(dataloader, desc="Extracting features"):
            # Get images and labels
            # Since we're using a ConcatDataset and random_split, 
            # we need to handle different dataset structures
            if isinstance(batch, list):
                # For custom datasets that return a list
                images, targets = batch
            elif hasattr(batch, 'data') and hasattr(batch, 'targets'):
                # For some torchvision datasets
                images, targets = batch.data, batch.targets
            else:
                # Default unpacking
                try:
                    images, targets = batch
                except:
                    print(f"Unexpected batch format: {type(batch)}")
                    continue
            
            # Make sure targets is a tensor
            if not isinstance(targets, torch.Tensor):
                targets = torch.tensor(targets)
            # Move to device
            images = images.to(device)
            # Forward pass through the model
            outputs = model(pixel_values=images)
            # Extract [CLS] token embeddings (first token)
            embeddings = outputs.last_hidden_state[:, 0, :]
            # Store features and labels
            features.append(embeddings.cpu().numpy())
            labels.append(targets.cpu().numpy())

        
    return np.vstack(features), np.concatenate(labels)

print("\nExtracting features for the training set...")
train_features, train_labels = extract_features(train_loader, model)

print("Extracting features for the test set...")
test_features, test_labels = extract_features(test_loader, model)
    
print(f"Train features shape: {train_features.shape}")
print(f"Test features shape: {test_features.shape}")


Extracting features for the training set...


Extracting features: 100%|██████████| 205/205 [02:31<00:00,  1.36it/s]


Extracting features for the test set...


Extracting features: 100%|██████████| 52/52 [00:37<00:00,  1.40it/s]

Train features shape: (6551, 384)
Test features shape: (1638, 384)





In [10]:
# Train Logistic Regression 
print("\nTraining logistic regression...")
clf = LogisticRegression(max_iter=2000, verbose=1, n_jobs=-1, random_state=RANDOM_SEED)
clf.fit(train_features, train_labels)
    
# Evaluate Accuracy 
predictions = clf.predict(test_features)
accuracy = accuracy_score(test_labels, predictions)
print(f"Accuracy on the test set: {accuracy:.2%}")
    
# Measure Latency
print("Measuring latency...")
start_time = time.time()
_ = clf.predict(test_features[:1])  # scikit-learn classifier latency
end_time = time.time()
latency = (end_time - start_time) * 1000
print(f"Latency (inference time for one sample): {latency:.2f} ms")

# Measure FLOPs of ViT 
print("Calculating FLOPs of ViT model...")
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)
model = model.to(device)
macs, params = profile(model, inputs=(dummy_input,), verbose=False)
flops = macs * 2
print(f"FLOPs (floating-point operations): {flops:,}")


Training logistic regression...


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =        39270     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  4.62497D+00    |proj g|=  5.12441D-01


 This problem is unconstrained.



At iterate   50    f=  3.76031D-03    |proj g|=  1.74471D-04

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
39270     71     77      1     0     0   4.533D-05   3.104D-03
  F =   3.1039340846413079E-003

CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL            
Accuracy on the test set: 97.99%
Measuring latency...
Latency (inference time for one sample): 0.63 ms
Calculating FLOPs of ViT model...
FLOPs (floating-point operations): 8,497,093,632.0
