# ViTalia - NonTrained ViT 

___________________________

In [2]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
import warnings


warnings.filterwarnings("ignore")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = models.vision_transformer.vit_b_16(weights=None, num_classes=2).to(device)


data_dir = '/kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/'


class MalariaDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(["Parasitized", "Uninfected"]):
            class_dir = os.path.join(root_dir, class_name)
            for img_file in os.listdir(class_dir):
                self.image_paths.append(os.path.join(class_dir, img_file))
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError):
            print(f"Error loading image {img_path}. Skipping.")
            return self.__getitem__((idx + 1) % len(self)) 

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        return image, label

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

full_dataset = MalariaDataset(data_dir, transform=transform)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])


train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=True)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 
criterion = torch.nn.CrossEntropyLoss()

num_epochs = 15
patience = 3
best_loss = float('inf')
early_stop_counter = 0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_bar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")


    if avg_loss < best_loss:
        best_loss = avg_loss
        early_stop_counter = 0 
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("Early stopping triggered.")
            break


model.eval()
predictions, true_labels = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        predictions.extend(predicted.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)
conf_matrix = confusion_matrix(true_labels, predictions)

print(f"Test Accuracy: {accuracy * 100:.2f}%")
print(f"Test Precision: {precision * 100:.2f}%")
print(f"Test Recall: {recall * 100:.2f}%")
print(f"Test F1 Score: {f1 * 100:.2f}%")
print("Confusion Matrix:")
print(conf_matrix)


Epoch [1/15]:  55%|█████▍    | 755/1378 [05:19<04:25,  2.35it/s, loss=0.061] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [1/15]:  58%|█████▊    | 804/1378 [05:40<04:01,  2.38it/s, loss=0.245] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [1/15]: 100%|██████████| 1378/1378 [09:44<00:00,  2.36it/s, loss=0.343] 


Epoch [1/15], Average Loss: 0.2224


Epoch [2/15]:  70%|██████▉   | 959/1378 [05:54<02:34,  2.71it/s, loss=0.0174]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [2/15]:  76%|███████▋  | 1054/1378 [06:29<01:59,  2.70it/s, loss=0.392] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [2/15]: 100%|██████████| 1378/1378 [08:29<00:00,  2.70it/s, loss=0.0429] 


Epoch [2/15], Average Loss: 0.1665


Epoch [3/15]:  23%|██▎       | 315/1378 [01:57<06:28,  2.74it/s, loss=0.073] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [3/15]:  53%|█████▎    | 725/1378 [04:29<04:02,  2.69it/s, loss=0.0193]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [3/15]: 100%|██████████| 1378/1378 [08:31<00:00,  2.69it/s, loss=0.0365] 


Epoch [3/15], Average Loss: 0.1545


Epoch [4/15]:  17%|█▋        | 238/1378 [01:27<07:03,  2.69it/s, loss=0.0505]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [4/15]:  84%|████████▍ | 1161/1378 [07:10<01:20,  2.71it/s, loss=0.245]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [4/15]: 100%|██████████| 1378/1378 [08:31<00:00,  2.69it/s, loss=0.211] 


Epoch [4/15], Average Loss: 0.1510


Epoch [5/15]:  82%|████████▏ | 1129/1378 [07:00<01:32,  2.70it/s, loss=0.0682]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [5/15]:  92%|█████████▏| 1264/1378 [07:51<00:41,  2.71it/s, loss=0.0377] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [5/15]: 100%|██████████| 1378/1378 [08:34<00:00,  2.68it/s, loss=0.0267]


Epoch [5/15], Average Loss: 0.1482


Epoch [6/15]:  70%|███████   | 968/1378 [05:57<02:30,  2.72it/s, loss=0.328]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [6/15]:  77%|███████▋  | 1065/1378 [06:33<01:57,  2.66it/s, loss=0.318] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [6/15]: 100%|██████████| 1378/1378 [08:29<00:00,  2.70it/s, loss=0.0525]


Epoch [6/15], Average Loss: 0.1423


Epoch [7/15]:  49%|████▉     | 679/1378 [04:11<04:19,  2.69it/s, loss=0.266]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [7/15]: 100%|██████████| 1378/1378 [08:29<00:00,  2.70it/s, loss=0.143]  


Epoch [7/15], Average Loss: 0.1434


Epoch [8/15]:  48%|████▊     | 657/1378 [04:02<04:24,  2.72it/s, loss=0.257]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [8/15]:  54%|█████▍    | 742/1378 [04:34<03:55,  2.70it/s, loss=0.102] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [8/15]: 100%|██████████| 1378/1378 [08:29<00:00,  2.70it/s, loss=0.0796] 


Epoch [8/15], Average Loss: 0.1423


Epoch [9/15]:  24%|██▎       | 325/1378 [02:00<06:37,  2.65it/s, loss=0.188]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [9/15]:  67%|██████▋   | 924/1378 [05:41<02:48,  2.70it/s, loss=0.0424]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [9/15]: 100%|██████████| 1378/1378 [08:29<00:00,  2.70it/s, loss=0.238]  


Epoch [9/15], Average Loss: 0.1433


Epoch [10/15]:   8%|▊         | 107/1378 [00:39<07:44,  2.74it/s, loss=0.045] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [10/15]:  23%|██▎       | 323/1378 [01:59<06:27,  2.72it/s, loss=0.182]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [10/15]: 100%|██████████| 1378/1378 [08:28<00:00,  2.71it/s, loss=0.102]  


Epoch [10/15], Average Loss: 0.1341


Epoch [11/15]:  30%|███       | 414/1378 [02:32<05:55,  2.71it/s, loss=0.00877]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [11/15]:  90%|█████████ | 1242/1378 [07:38<00:50,  2.72it/s, loss=0.178]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [11/15]: 100%|██████████| 1378/1378 [08:28<00:00,  2.71it/s, loss=0.268] 


Epoch [11/15], Average Loss: 0.1301


Epoch [12/15]:  29%|██▉       | 402/1378 [02:28<05:58,  2.73it/s, loss=0.0265] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [12/15]:  35%|███▍      | 477/1378 [02:56<05:32,  2.71it/s, loss=0.0138] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [12/15]: 100%|██████████| 1378/1378 [08:31<00:00,  2.69it/s, loss=0.0138] 


Epoch [12/15], Average Loss: 0.1276


Epoch [13/15]:  46%|████▌     | 637/1378 [03:59<04:46,  2.59it/s, loss=0.318]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [13/15]:  65%|██████▌   | 896/1378 [05:39<02:58,  2.70it/s, loss=0.35]   

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [13/15]: 100%|██████████| 1378/1378 [08:41<00:00,  2.64it/s, loss=0.0222] 


Epoch [13/15], Average Loss: 0.1253


Epoch [14/15]:   8%|▊         | 105/1378 [00:40<07:57,  2.67it/s, loss=0.0437]

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [14/15]:  22%|██▏       | 306/1378 [01:57<06:39,  2.68it/s, loss=0.0457] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [14/15]: 100%|██████████| 1378/1378 [08:45<00:00,  2.62it/s, loss=0.0204] 


Epoch [14/15], Average Loss: 0.1255


Epoch [15/15]:  23%|██▎       | 321/1378 [01:58<06:28,  2.72it/s, loss=0.0506] 

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.


Epoch [15/15]:  38%|███▊      | 524/1378 [03:13<05:15,  2.71it/s, loss=0.323]  

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.


Epoch [15/15]: 100%|██████████| 1378/1378 [08:30<00:00,  2.70it/s, loss=0.0179] 


Epoch [15/15], Average Loss: 0.1230
Test Accuracy: 96.08%
Test Precision: 95.22%
Test Recall: 97.08%
Test F1 Score: 96.14%
Confusion Matrix:
[[2604  135]
 [  81 2692]]


_______________________________________