# ViTalia - Google ViT v/s Swin Transformer

_______________________

In [3]:
import os
import time
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor, SwinForImageClassification, SwinConfig
from sklearn.metrics import accuracy_score
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_dir = '/kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/'

class MalariaDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(["Parasitized", "Uninfected"]):
            class_dir = os.path.join(root_dir, class_name)
            for img_file in os.listdir(class_dir):
                self.image_paths.append(os.path.join(class_dir, img_file))
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError):
            print(f"Error loading image {img_path}. Skipping.")
            return self.__getitem__((idx + 1) % len(self)) 

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        return image, label

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

full_dataset = MalariaDataset(data_dir, transform=transform)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=True)

def train_model(model, epochs=3):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    criterion = torch.nn.CrossEntropyLoss()
    
    # Training loop with timing
    start_time = time.time()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]")

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            progress_bar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}")

    end_time = time.time()
    training_time = end_time - start_time
    return training_time

def evaluate_model(model):
    model.eval()
    predictions, true_labels = [], []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images).logits
            _, predicted = torch.max(outputs, 1)

            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    return accuracy

print("Initializing ViT model...")
vit_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=2)
vit_model.classifier = torch.nn.Linear(vit_model.classifier.in_features, 2)

print("\nTraining ViT model...")
vit_time = train_model(vit_model, epochs=5)
vit_accuracy = evaluate_model(vit_model)
print(f"ViT Training Time: {vit_time:.2f} seconds, Accuracy: {vit_accuracy * 100:.2f}%")

print("\nInitializing Swin Transformer model...")
swin_model = SwinForImageClassification(SwinConfig(num_labels=2))

print("\nTraining Swin Transformer model...")
swin_time = train_model(swin_model, epochs=5)
swin_accuracy = evaluate_model(swin_model)
print(f"Swin Training Time: {swin_time:.2f} seconds, Accuracy: {swin_accuracy * 100:.2f}%")

print("\nModel Comparison:")
print(f"ViT Model - Time: {vit_time:.2f}s, Accuracy: {vit_accuracy * 100:.2f}%")
print(f"Swin Model - Time: {swin_time:.2f}s, Accuracy: {swin_accuracy * 100:.2f}%")

Initializing ViT model...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Training ViT model...


Epoch [1/5]:  94%|█████████▍| 1300/1378 [07:56<00:27,  2.81it/s, loss=0.02]   

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [1/5]:  99%|█████████▉| 1369/1378 [08:21<00:03,  2.72it/s, loss=0.058]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [1/5]: 100%|██████████| 1378/1378 [08:25<00:00,  2.73it/s, loss=0.0187]


Epoch [1/5], Average Loss: 0.1111


Epoch [2/5]:  42%|████▏     | 574/1378 [03:24<04:44,  2.82it/s, loss=0.0209] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [2/5]:  72%|███████▏  | 988/1378 [05:53<02:18,  2.81it/s, loss=0.00626]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [2/5]: 100%|██████████| 1378/1378 [08:13<00:00,  2.79it/s, loss=0.0994] 


Epoch [2/5], Average Loss: 0.0660


Epoch [3/5]:  37%|███▋      | 514/1378 [03:03<05:06,  2.82it/s, loss=0.0198] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [3/5]:  79%|███████▊  | 1083/1378 [06:26<01:47,  2.75it/s, loss=0.0024] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [3/5]: 100%|██████████| 1378/1378 [08:12<00:00,  2.80it/s, loss=0.0192] 


Epoch [3/5], Average Loss: 0.0417


Epoch [4/5]:  14%|█▍        | 199/1378 [01:10<07:07,  2.75it/s, loss=0.37]   

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [4/5]:  60%|█████▉    | 826/1378 [04:53<03:15,  2.82it/s, loss=0.00358]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [4/5]: 100%|██████████| 1378/1378 [08:10<00:00,  2.81it/s, loss=0.00169]


Epoch [4/5], Average Loss: 0.0310


Epoch [5/5]:  45%|████▍     | 616/1378 [03:39<04:29,  2.82it/s, loss=0.00245]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [5/5]:  83%|████████▎ | 1139/1378 [06:44<01:26,  2.76it/s, loss=0.0024]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [5/5]: 100%|██████████| 1378/1378 [08:09<00:00,  2.81it/s, loss=0.00931] 


Epoch [5/5], Average Loss: 0.0205
ViT Training Time: 2471.08 seconds, Accuracy: 97.51%

Initializing Swin Transformer model...

Training Swin Transformer model...


Epoch [1/5]:  81%|████████▏ | 1122/1378 [03:31<00:57,  4.47it/s, loss=0.0864]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [1/5]:  97%|█████████▋| 1334/1378 [04:11<00:10,  4.27it/s, loss=0.0362]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [1/5]: 100%|██████████| 1378/1378 [04:19<00:00,  5.31it/s, loss=0.0543]


Epoch [1/5], Average Loss: 0.3423


Epoch [2/5]:  10%|█         | 139/1378 [00:26<04:43,  4.37it/s, loss=0.344] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [2/5]:  62%|██████▏   | 858/1378 [02:40<01:56,  4.46it/s, loss=0.0844] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [2/5]: 100%|██████████| 1378/1378 [04:18<00:00,  5.33it/s, loss=0.288] 


Epoch [2/5], Average Loss: 0.1489


Epoch [3/5]:  24%|██▍       | 333/1378 [01:02<04:09,  4.19it/s, loss=0.28]   

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [3/5]:  86%|████████▌ | 1186/1378 [03:43<00:47,  4.06it/s, loss=0.362]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [3/5]: 100%|██████████| 1378/1378 [04:19<00:00,  5.31it/s, loss=0.167] 


Epoch [3/5], Average Loss: 0.1372


Epoch [4/5]:  79%|███████▉  | 1086/1378 [03:24<01:43,  2.83it/s, loss=0.017]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [4/5]:  90%|████████▉ | 1236/1378 [03:52<00:32,  4.39it/s, loss=0.028]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [4/5]: 100%|██████████| 1378/1378 [04:19<00:00,  5.31it/s, loss=0.256]  


Epoch [4/5], Average Loss: 0.1300


Epoch [5/5]:  15%|█▌        | 211/1378 [00:39<04:20,  4.48it/s, loss=0.166]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [5/5]:  40%|████      | 553/1378 [01:44<03:16,  4.20it/s, loss=0.0118] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [5/5]: 100%|██████████| 1378/1378 [04:19<00:00,  5.31it/s, loss=0.195]  


Epoch [5/5], Average Loss: 0.1251
Swin Training Time: 1296.47 seconds, Accuracy: 96.03%

Model Comparison:
ViT Model - Time: 2471.08s, Accuracy: 97.51%
Swin Model - Time: 1296.47s, Accuracy: 96.03%


-------------------------------