In [8]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

In [9]:
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

In [10]:
def transform(image):
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    return feature_extractor(images=image, return_tensors='pt').pixel_values.squeeze()

In [12]:
train_data = ImageFolder(root='train', transform=transform)
val_data = ImageFolder(root='validation', transform=transform)
test_data = ImageFolder(root='test', transform=transform)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=2, ignore_mismatched_sizes=True  )
model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [16]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)


In [17]:
import time
import torch
from tqdm import tqdm

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, patience=3, save_best_model=False, verbose=True):
    model.to(device)
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        total_train_loss = 0
        total_train_correct = 0
        start_time = time.time()

        # Training loop
        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} [Training]', unit='batch'):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs).logits  # For ViT, use logits attribute
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            total_train_correct += torch.sum(preds == labels.data)

        # Validation loop
        model.eval()  # Set the model to evaluation mode
        total_val_loss = 0
        total_val_correct = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f'Epoch {epoch + 1}/{num_epochs} [Validation]', unit='batch'):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs).logits  # For ViT, use logits attribute
                loss = criterion(outputs, labels)

                total_val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                total_val_correct += torch.sum(preds == labels.data)

        # Calculate average losses and accuracy
        avg_train_loss = total_train_loss / len(train_loader.dataset)
        avg_val_loss = total_val_loss / len(val_loader.dataset)
        train_acc = total_train_correct.double() / len(train_loader.dataset)
        val_acc = total_val_correct.double() / len(val_loader.dataset)
        epoch_duration = time.time() - start_time

        # Print training/validation statistics
        if verbose:
            print(f'Epoch {epoch + 1}/{num_epochs}, Duration: {epoch_duration:.2f}s, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Early stopping and saving the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = model.state_dict()
            if save_best_model:
                torch.save(model.state_dict(), 'best_model_vit.pth')
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                if verbose:
                    print(f'Early stopping triggered after {epoch + 1} epochs!')
                model.load_state_dict(best_model_state)
                break

    return model

In [18]:
model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, patience=3, save_best_model=True)

Epoch 1/10 [Training]: 100%|██████████████████████████████████████████████████████| 225/225 [04:07<00:00,  1.10s/batch]
Epoch 1/10 [Validation]: 100%|██████████████████████████████████████████████████████| 48/48 [00:29<00:00,  1.62batch/s]


Epoch 1/10, Duration: 277.35s, Train Loss: 0.0159, Train Acc: 0.8035, Val Loss: 0.0160, Val Acc: 0.8073


Epoch 2/10 [Training]: 100%|██████████████████████████████████████████████████████| 225/225 [03:41<00:00,  1.02batch/s]
Epoch 2/10 [Validation]: 100%|██████████████████████████████████████████████████████| 48/48 [00:26<00:00,  1.84batch/s]


Epoch 2/10, Duration: 247.56s, Train Loss: 0.0157, Train Acc: 0.8030, Val Loss: 0.0152, Val Acc: 0.8073


Epoch 3/10 [Training]: 100%|██████████████████████████████████████████████████████| 225/225 [03:55<00:00,  1.05s/batch]
Epoch 3/10 [Validation]: 100%|██████████████████████████████████████████████████████| 48/48 [00:29<00:00,  1.61batch/s]


Epoch 3/10, Duration: 265.07s, Train Loss: 0.0153, Train Acc: 0.8035, Val Loss: 0.0150, Val Acc: 0.8073


Epoch 4/10 [Training]: 100%|██████████████████████████████████████████████████████| 225/225 [03:53<00:00,  1.04s/batch]
Epoch 4/10 [Validation]: 100%|██████████████████████████████████████████████████████| 48/48 [00:29<00:00,  1.65batch/s]


Epoch 4/10, Duration: 262.31s, Train Loss: 0.0153, Train Acc: 0.8033, Val Loss: 0.0152, Val Acc: 0.8073


Epoch 5/10 [Training]: 100%|██████████████████████████████████████████████████████| 225/225 [03:55<00:00,  1.05s/batch]
Epoch 5/10 [Validation]: 100%|██████████████████████████████████████████████████████| 48/48 [00:30<00:00,  1.59batch/s]


Epoch 5/10, Duration: 266.10s, Train Loss: 0.0153, Train Acc: 0.8035, Val Loss: 0.0150, Val Acc: 0.8073


Epoch 6/10 [Training]: 100%|██████████████████████████████████████████████████████| 225/225 [04:00<00:00,  1.07s/batch]
Epoch 6/10 [Validation]: 100%|██████████████████████████████████████████████████████| 48/48 [00:32<00:00,  1.49batch/s]

Epoch 6/10, Duration: 272.73s, Train Loss: 0.0152, Train Acc: 0.8035, Val Loss: 0.0150, Val Acc: 0.8073
Early stopping triggered after 6 epochs!





In [22]:
torch.save(model.state_dict(), 'final_model_vit.pth')

Test

In [23]:
import numpy as np
from sklearn.metrics import accuracy_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224', 
    num_labels=2,
    ignore_mismatched_sizes=True
)
model.load_state_dict(torch.load('final_model_vit.pth'))
model.to(device)
model.eval()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [25]:
labels = []
scores = []

with torch.no_grad():
    for inputs, batch_labels in test_loader:
        inputs = inputs.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(inputs).logits
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

        labels.extend(batch_labels.cpu().numpy())
        scores.extend(probabilities.cpu().numpy())

# Convert scores and labels to numpy arrays for metric calculations
labels = np.array(labels)
scores = np.array(scores)

In [29]:
def calculate_metrics(labels, scores, far_target=1e-3):
    labels = np.array(labels)
    scores = np.array(scores)[:, 1]  # Take the probabilities of the positive class

    # Accuracy
    predictions = (scores > 0.5).astype(int) 
    accuracy = accuracy_score(labels, predictions)

    # Calculate ROC Curve and EER
    fpr, tpr, thresholds = roc_curve(labels, scores)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

    # Find TAR at specified FAR
    far_index = np.where(fpr <= far_target)[0][-1]
    tar_at_far = tpr[far_index]

    return accuracy, eer, tar_at_far

# Calculate the metrics
accuracy, eer, tar_at_far = calculate_metrics(labels, scores)
far_target = 1e-3

print(f'Accuracy: {accuracy:.4f}')
print(f'EER: {eer:.4f}')
print(f'TAR at FAR={far_target}: {tar_at_far:.4f}')

Accuracy: 0.1874
EER: 0.5000
TAR at FAR=0.001: 0.0000
