In [3]:
import torch
print(torch.cuda.is_available())   # Should return True
print(torch.cuda.get_device_name(0))  # Should show RTX 4060

True
NVIDIA GeForce RTX 4060 Laptop GPU


In [6]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Stage 2 - Paths
stage2_train_dir = "../Dataset/stage2/train"
stage2_val_dir = "../Dataset/stage2/val"
stage2_test_dir = "../Dataset/stage2/test"

# Transforms (same as stage 1)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Datasets and loaders
train_dataset2 = datasets.ImageFolder(stage2_train_dir, transform=transform)
val_dataset2 = datasets.ImageFolder(stage2_val_dir, transform=transform)
test_dataset2 = datasets.ImageFolder(stage2_test_dir, transform=transform)

train_loader2 = DataLoader(train_dataset2, batch_size=32, shuffle=True)
val_loader2 = DataLoader(val_dataset2, batch_size=32)
test_loader2 = DataLoader(test_dataset2, batch_size=32)

# Class names
class_names2 = train_dataset2.classes
print("Stage 2 Classes:", class_names2)



Stage 2 Classes: ['Bacterial Pneumonia', 'Viral Pneumonia']


In [4]:
from transformers import ViTForImageClassification
# Define model
model2 = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=2  # Binary classification
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model2.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [8]:
from tqdm import tqdm
# Optimizer and loss
optimizer2 = torch.optim.Adam(model2.parameters(), lr=2e-5)
criterion2 = torch.nn.CrossEntropyLoss()

# Early stopping params
best_val_loss2 = float('inf')
patience = 3
trigger_times = 0

# Training loop
for epoch in range(1, 16):
    model2.train()
    train_loss = 0.0

    loop = tqdm(train_loader2, desc=f"Epoch {epoch}/15")
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer2.zero_grad()
        outputs = model2(images).logits
        loss = criterion2(outputs, labels)
        loss.backward()
        optimizer2.step()

        train_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_train_loss = train_loss / len(train_loader2)

    # Validation
    model2.eval()
    val_loss = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in val_loader2:
            images, labels = images.to(device), labels.to(device)
            outputs = model2(images).logits
            loss = criterion2(outputs, labels)
            val_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = val_loss / len(val_loader2)
    val_accuracy = 100 * correct / total

    print(f"\nEpoch {epoch} finished. Avg Train Loss: {avg_train_loss:.4f}")
    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

    # Check for improvement
    if avg_val_loss < best_val_loss2:
        best_val_loss2 = avg_val_loss
        torch.save(model2.state_dict(), "../Model/best_model_stage2.pth")
        print(f" Best model saved at epoch {epoch}.")
        trigger_times = 0
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f" Early stopping triggered after {epoch} epochs.")
            break


Epoch 1/15: 100%|██████████| 76/76 [01:53<00:00,  1.49s/it, loss=0.691]



Epoch 1 finished. Avg Train Loss: 0.5870
Validation Loss: 0.5308, Accuracy: 73.07%
✅ Best model saved at epoch 1.


Epoch 2/15: 100%|██████████| 76/76 [01:23<00:00,  1.10s/it, loss=0.577]



Epoch 2 finished. Avg Train Loss: 0.4674
Validation Loss: 0.4730, Accuracy: 78.55%
✅ Best model saved at epoch 2.


Epoch 3/15: 100%|██████████| 76/76 [01:23<00:00,  1.10s/it, loss=0.568]



Epoch 3 finished. Avg Train Loss: 0.3515
Validation Loss: 0.4376, Accuracy: 80.92%
✅ Best model saved at epoch 3.


Epoch 4/15: 100%|██████████| 76/76 [01:05<00:00,  1.16it/s, loss=0.0834]



Epoch 4 finished. Avg Train Loss: 0.2535
Validation Loss: 0.4766, Accuracy: 78.80%


Epoch 5/15: 100%|██████████| 76/76 [01:28<00:00,  1.17s/it, loss=0.0433]



Epoch 5 finished. Avg Train Loss: 0.1672
Validation Loss: 0.5560, Accuracy: 79.18%


Epoch 6/15: 100%|██████████| 76/76 [01:21<00:00,  1.07s/it, loss=0.0763]



Epoch 6 finished. Avg Train Loss: 0.1061
Validation Loss: 0.6319, Accuracy: 79.05%
🛑 Early stopping triggered after 6 epochs.


In [7]:
# Load best model
model2.load_state_dict(torch.load("../Model/best_model_stage2.pth",weights_only=True))
model2.eval()

correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader2:
        images, labels = images.to(device), labels.to(device)
        outputs = model2(images).logits
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_accuracy = 100 * correct / total
print(f" Final Test Accuracy (Stage 2): {test_accuracy:.2f}%")


 Final Test Accuracy (Stage 2): 78.91%


HyperParameter Tuning

In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import optuna
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from transformers import ViTForImageClassification
from optuna.exceptions import TrialPruned
import time
from tqdm import tqdm

def objective(trial):
    # Suggest hyperparameters
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])
    num_epochs = trial.suggest_int('num_epochs', 5, 15)
    hidden_size = trial.suggest_int('hidden_size', 128, 512, step=64)

    # Model setup
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k",
        num_labels=2
    )
    model.classifier = nn.Linear(model.classifier.in_features, 2)
    model.to(device)

    # Optimizer & loss
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # DataLoaders with performance tweak
    train_loader = DataLoader(train_dataset2, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset2, batch_size=batch_size, num_workers=4, pin_memory=True)

    start_time = time.time()

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        for images, labels in tqdm(train_loader, desc=f"[Trial {trial.number}] Epoch {epoch+1}/{num_epochs}", leave=False):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Validation
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, preds = torch.max(outputs.logits, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        accuracy = accuracy_score(all_labels, all_preds)
        trial.report(accuracy, epoch)

        # Prune if not promising
        if trial.should_prune():
            raise TrialPruned()

    end_time = time.time()
    print(f"⏱️ Trial {trial.number} completed in {(end_time - start_time)/60:.2f} minutes with accuracy {accuracy:.4f}")

    return accuracy


In [None]:
study = optuna.create_study(direction='maximize')  # We want to maximize accuracy
study.optimize(objective, n_trials=30)
print(f"Best hyperparameters: {study.best_params}")