# ViTalia - Malaria Detection using Vision Transformers

______________________________

### Malaria Detection

Malaria is a serious and life-threatening disease caused by *Plasmodium* parasites, transmitted to humans through the bites of infected female *Anopheles* mosquitoes. Though malaria is preventable and curable, it remains a major global health concern.

- **Global Burden**: In 2017, there were approximately 219 million cases of malaria across 90 countries, leading to around 435,000 deaths. Most cases and deaths occurred in the WHO African Region, which accounted for 92% of cases and 93% of deaths.
- **Affected Populations**: Malaria has a disproportionate impact on children under five and pregnant women, particularly in sub-Saharan Africa. Vulnerable populations in regions with inadequate access to healthcare are at higher risk of severe complications and death.

### Understanding Malaria and its Parasites

The *Plasmodium* parasites are spread by "malaria vectors": infected female *Anopheles* mosquitoes. Five parasite species cause malaria in humans, and two of them, *P. falciparum* and *P. vivax*, pose the greatest threat because of their potential for severe infection and spread.

1. ***P. falciparum***: The deadliest species, responsible for the majority of malaria-related deaths. It is found mainly in sub-Saharan Africa.
2. ***P. vivax***: The main cause of relapsing malaria, because dormant liver-stage parasites can reactivate weeks or months after the first infection. It is common in Asia and the Americas, and the relapses complicate treatment.

### Challenges in Malaria Diagnosis

Diagnosing malaria can be difficult, especially in areas where it is no longer endemic. Healthcare providers in these regions may not consider malaria as a possible diagnosis, and laboratory staff may lack experience in identifying the parasites under a microscope. Early symptoms such as fever, chills, and headache can be mild and are easily mistaken for other illnesses. If not treated within 24 hours, *P. falciparum* malaria can progress to severe illness with life-threatening complications.

### Methods of Malaria Detection

#### Microscopic Diagnosis

Microscopy remains the gold standard for malaria diagnosis. A blood smear is prepared from the patient’s blood, stained to highlight the parasites, and examined under a microscope. The method is highly effective, but it depends on high-quality reagents, reliable microscopes, and skilled laboratory staff who can detect the parasites and tell the species apart. Poor equipment or technique reduces diagnostic accuracy.

#### Rapid Diagnostic Tests (RDTs)

Rapid Diagnostic Tests (RDTs) offer an alternative to microscopy, especially in remote or resource-limited areas. These tests are easy to use, provide quick results, and don’t require extensive training, making them suitable for low-resource settings. However, RDTs can vary in accuracy and may not detect all *Plasmodium* species effectively.

#### Molecular Methods

Advanced molecular methods like Polymerase Chain Reaction (PCR) provide highly sensitive and specific malaria diagnoses. Although PCR is less common in field settings due to its cost and complexity, it is used in reference laboratories and research settings for identifying parasite species and drug resistance markers. 

![Malaria](https://cdn1.sph.harvard.edu/wp-content/uploads/2015/03/Malaria-cells_CDC.jpg)

### References

- [WHO Fact Sheet on Malaria](https://www.who.int/news-room/fact-sheets/detail/malaria)
- [CDC Malaria Diagnosis and Treatment](https://www.cdc.gov/malaria/diagnosis_treatment/diagnosis.html)


In [1]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
# ViTImageProcessor replaces ViTFeatureExtractor, which is deprecated in recent transformers releases.
from transformers import ViTForImageClassification, ViTImageProcessor
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
import numpy as np
import warnings

# Silence library warnings to keep the notebook output short.
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ViT backbone (ImageNet-21k) with a freshly initialized 2-class head.
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=2)
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# num_labels=2 already creates a 2-class head; this line only re-initializes it.
model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
model.to(device)

data_dir = '/kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/'

class MalariaDataset(Dataset):
    def __init__(self, root_dir, feature_extractor, transform=None):
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.transform = transform
        self.image_paths = []
        self.labels = []

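        # Label 0 = Parasitized, label 1 = Uninfected (the order of the list below).
        # os.listdir returns every file in the folder, including non-image files such
        # as Thumbs.db; __getitem__ handles those when they fail to load.
        for label, class_name in enumerate(["Parasitized", "Uninfected"]):
            class_dir = os.path.join(root_dir, class_name)
            for img_file in os.listdir(class_dir):
                self.image_paths.append(os.path.join(class_dir, img_file))
                self.labels.append(label)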

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
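        try:
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError):
            # Unreadable file: fall back to the next sample (image and label), which
            # therefore appears twice in the epoch.
            print(f"Error loading image {img_path}. Skipping.")
            return self.__getitem__((idx + 1) % len(self))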

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        return image, label


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])

# 80/20 random split. No seed is set, so the split differs from run to run.
full_dataset = MalariaDataset(data_dir, feature_extractor, transform=transform)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=True)


optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
criterion = torch.nn.CrossEntropyLoss()

num_epochs = 15
patience = 3
best_val_loss = np.inf
early_stop_count = 0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_bar.set_postfix(loss=loss.item())

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

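    # Validation pass. test_loader doubles as the validation set here, so the final
    # test metrics are computed on data that also drove early stopping.
    model.eval()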
    val_loss = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            val_loss += criterion(outputs, labels).item()

    avg_val_loss = val_loss / len(test_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}")

    # Keep the checkpoint with the lowest validation loss seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        early_stop_count = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        early_stop_count += 1
        if early_stop_count >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

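# Restore the best checkpoint; map_location keeps this working on CPU-only machines.
model.load_state_dict(torch.load("best_model.pth", map_location=device))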

predictions, true_labels = [], []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images).logits
        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

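# Labels: 0 = Parasitized, 1 = Uninfected. scikit-learn's binary metrics default to
# pos_label=1, so precision, recall and F1 below describe the Uninfected class.
accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)
conf_matrix = confusion_matrix(true_labels, predictions)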

print(f"Test Accuracy: {accuracy * 100:.2f}%")
print(f"Test Precision: {precision * 100:.2f}%")
print(f"Test Recall: {recall * 100:.2f}%")
print(f"Test F1 Score: {f1 * 100:.2f}%")
print("Confusion Matrix:")
print(conf_matrix)
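
The dataset class above indexes every file that `os.listdir` returns, so non-image files such as `Thumbs.db` only surface as load errors during training. One possible alternative, shown here as a minimal sketch and not as part of the original run, is to filter by file extension while building the index. The helper name `list_image_files` and the extension set are assumptions made for illustration.

```python
import os

# Assumed set of image extensions; adjust it to the dataset at hand.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def list_image_files(class_dir, extensions=IMAGE_EXTENSIONS):
    """Return sorted paths of regular files in class_dir with an image-like extension."""
    paths = []
    for name in sorted(os.listdir(class_dir)):
        full_path = os.path.join(class_dir, name)
        ext = os.path.splitext(name)[1].lower()  # case-insensitive extension check
        if os.path.isfile(full_path) and ext in extensions:
            paths.append(full_path)
    return paths
```

Inside `MalariaDataset.__init__`, the inner loop could then iterate over `list_image_files(class_dir)` in place of `os.listdir(class_dir)`. Doing so changes the dataset size and the random split, so the numbers in the recorded output would not be reproduced exactly.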


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Parasitized/Thumbs.db. Skipping.

Error loading image /kaggle/input/cell-images-for-detecting-malaria/cell_images/cell_images/Uninfected/Thumbs.db. Skipping.

Epoch [1/15]: 100%|██████████| 1378/1378 [09:27<00:00,  2.43it/s, loss=0.00759]

Epoch [1/15], Training Loss: 0.1090
Epoch [1/15], Validation Loss: 0.0755

Epoch [2/15]: 100%|██████████| 1378/1378 [08:08<00:00,  2.82it/s, loss=0.00614]

Epoch [2/15], Training Loss: 0.0633
Epoch [2/15], Validation Loss: 0.0791

Epoch [3/15]: 100%|██████████| 1378/1378 [08:10<00:00,  2.81it/s, loss=0.179]

Epoch [3/15], Training Loss: 0.0403
Epoch [3/15], Validation Loss: 0.0765

Epoch [4/15]: 100%|██████████| 1378/1378 [08:10<00:00,  2.81it/s, loss=0.363]

Epoch [4/15], Training Loss: 0.0287
Epoch [4/15], Validation Loss: 0.0885
Early stopping at epoch 4
Test Accuracy: 97.53%
Test Precision: 96.92%
Test Recall: 98.18%
Test F1 Score: 97.55%
Confusion Matrix:
[[2673   86]
 [  50 2703]]
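
The printed metrics can be re-derived from the confusion matrix, which also shows which class they describe. The sketch below assumes the scikit-learn layout (rows are true labels, columns are predicted labels) and the label order used in `MalariaDataset` (0 = Parasitized, 1 = Uninfected). The helper `binary_metrics` is a hypothetical name introduced here for illustration.

```python
# Confusion matrix copied from the output above.
# Rows: true class, columns: predicted class (0 = Parasitized, 1 = Uninfected).
conf_matrix = [[2673, 86],
               [50, 2703]]


def binary_metrics(cm, positive):
    """Precision, recall and F1 of a 2x2 confusion matrix for the chosen positive class."""
    negative = 1 - positive
    tp = cm[positive][positive]   # positives predicted as positive
    fp = cm[negative][positive]   # negatives predicted as positive
    fn = cm[positive][negative]   # positives predicted as negative
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


total = sum(sum(row) for row in conf_matrix)
accuracy = (conf_matrix[0][0] + conf_matrix[1][1]) / total
uninfected = binary_metrics(conf_matrix, positive=1)
parasitized = binary_metrics(conf_matrix, positive=0)

print(f"Accuracy: {accuracy * 100:.2f}%")
for name, (p, r, f) in [("Uninfected", uninfected), ("Parasitized", parasitized)]:
    print(f"{name} as positive: precision={p * 100:.2f}% recall={r * 100:.2f}% F1={f * 100:.2f}%")
```

With Uninfected as the positive class this gives 96.92% precision, 98.18% recall and 97.55% F1, matching the reported values. Treating Parasitized as the positive class, which is usually the clinically relevant view, gives 98.16% precision and 96.88% recall: 86 of the 2,759 parasitized cells in the test split were classified as uninfected.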


_________________________
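
The training loop uses `test_loader` both for early stopping and for the final metrics, so the reported test scores come from data that already influenced model selection. A common remedy is a three-way split with a separate validation set. The sketch below is one way to produce such a split using only the standard library; the function name `split_indices`, the 80/10/10 fractions and the seed are assumptions, not values from the original run.

```python
import random


def split_indices(n_total, val_fraction=0.1, test_fraction=0.1, seed=42):
    """Shuffle range(n_total) with a fixed seed and cut it into train/val/test index lists."""
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    n_val = int(val_fraction * n_total)
    n_test = int(test_fraction * n_total)
    val_idx = indices[:n_val]
    test_idx = indices[n_val:n_val + n_test]
    train_idx = indices[n_val + n_test:]
    return train_idx, val_idx, test_idx


# Stand-in size for illustration; in the notebook this would be len(full_dataset).
train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

Each index list could then be wrapped with `torch.utils.data.Subset(full_dataset, train_idx)` (and likewise for the other two) and passed to its own `DataLoader`. Early stopping would monitor the validation loader, and the test loader would be used once, after training.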