<a href="https://colab.research.google.com/github/shakuvish/chest_xray/blob/main/x_ray.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from google.colab import drive
drive.mount('/content/drive')


ValueError: Mountpoint must not already contain files

In [8]:
!pip install -q timm torchmetrics

import os, torch, timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, roc_auc_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

from google.colab import drive
drive.mount('/content/drive')


def has_valid_images(folder, valid_exts=('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
    for root, _, files in os.walk(folder):
        if any(file.lower().endswith(valid_exts) for file in files):
            return True
    return False

def find_split_folders(base_path):
    split_paths = {}
    for root, dirs, files in os.walk(base_path):
        if '__MACOSX' in root: continue
        folder_name = os.path.basename(root).lower()
        if folder_name in ['train', 'val', 'test'] and len(dirs) >= 1:
            if has_valid_images(root):
                split_paths[folder_name] = root
    return split_paths


base_dataset_path = '/content/drive/MyDrive/ChestXRay2017'
splits = find_split_folders(base_dataset_path)

print("✅ Verified Dataset Structure:")
for k, v in splits.items():
    print(f" - {k}: {v}")


train_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
val_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


train_ds = datasets.ImageFolder(splits['train'], transform=train_tfms)
val_ds = datasets.ImageFolder(splits['val'], transform=val_tfms) if 'val' in splits else train_ds

if 'test' in splits and has_valid_images(splits['test']):
    test_ds = datasets.ImageFolder(splits['test'], transform=val_tfms)
else:
    print("⚠️ No valid test folder found. Using validation set for testing.")
    test_ds = val_ds

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)
test_loader = DataLoader(test_ds, batch_size=32)
class_names = train_ds.classes

print(f"\n✅ Loaded: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test images")
print("📂 Classes:", class_names)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=2).to(device)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_model(epochs=5):
    for epoch in range(epochs):
        model.train()
        running_loss, correct = 0.0, 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
        acc = 100 * correct / len(train_loader.dataset)
        print(f"✅ Epoch {epoch+1}: Loss={running_loss/len(train_loader):.4f}, Accuracy={acc:.2f}%")

train_model(epochs=5)

model.eval()
y_true, y_pred, probs = [], [], []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        y_true.extend(labels.numpy())
        y_pred.extend(outputs.argmax(1).cpu().numpy())
        probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())

print("\n📄 Classification Report:")
print(classification_report(y_true, y_pred, target_names=class_names))
print("📈 ROC AUC Score:", roc_auc_score(y_true, probs))

def predict_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = val_tfms(image).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(image)
        prob = torch.softmax(output, dim=1)
        pred = torch.argmax(prob, dim=1).item()
    print(f"🔍 Prediction: {class_names[pred]} (Confidence: {prob[0][pred]:.2f})")

# Example:
# predict_image('/content/drive/MyDrive/ChestXRay2017/chest_xray/test/NORMAL/IM-0001-0001.jpeg')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Verified Dataset Structure:
 - train: /content/drive/MyDrive/ChestXRay2017/chest_xray/train
⚠️ No valid test folder found. Using validation set for testing.

✅ Loaded: 5019 train, 5019 val, 5019 test images
📂 Classes: ['NORMAL', 'PNEUMONIA']


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Epoch 1/5: 100%|██████████| 157/157 [10:29<00:00,  4.01s/it]


✅ Epoch 1: Loss=0.3482, Accuracy=84.72%


Epoch 2/5: 100%|██████████| 157/157 [04:23<00:00,  1.68s/it]


✅ Epoch 2: Loss=0.1390, Accuracy=94.62%


Epoch 3/5: 100%|██████████| 157/157 [04:23<00:00,  1.68s/it]


✅ Epoch 3: Loss=0.1127, Accuracy=95.82%


Epoch 4/5: 100%|██████████| 157/157 [04:21<00:00,  1.67s/it]


✅ Epoch 4: Loss=0.0970, Accuracy=96.47%


Epoch 5/5: 100%|██████████| 157/157 [04:22<00:00,  1.67s/it]


✅ Epoch 5: Loss=0.0867, Accuracy=96.87%

📄 Classification Report:
              precision    recall  f1-score   support

      NORMAL       0.92      0.98      0.95      1136
   PNEUMONIA       0.99      0.97      0.98      3883

    accuracy                           0.97      5019
   macro avg       0.96      0.98      0.96      5019
weighted avg       0.98      0.97      0.98      5019

📈 ROC AUC Score: 0.9965369088079856


In [9]:
torch.save(model.state_dict(), '/content/drive/MyDrive/tb_vit_model.pth')
print("✅ Model saved to Google Drive.")


✅ Model saved to Google Drive.


In [10]:
model.load_state_dict(torch.load('/content/drive/MyDrive/tb_vit_model.pth'))
model.eval()
print("✅ Model loaded.")


✅ Model loaded.


In [11]:
predict_image('/content/drive/MyDrive/ChestXRay2017/chest_xray/train/NORMAL/IM-0115-0001.jpeg')


🔍 Prediction: NORMAL (Confidence: 1.00)
