# Fine-tune ViT for Chest X-ray Classification (Local, Feature Extraction)

This notebook fine-tunes a Vision Transformer (ViT) on the Chest X-ray dataset (Pneumonia vs Normal) using your local dataset at `mediscan/data/chest_xray`. Only the classification head is trained (feature extraction).

## 1. Install Dependencies

Install Hugging Face Transformers, Datasets, and other requirements.

In [4]:
%pip install transformers datasets pillow torch torchvision




## 2. Set Dataset Paths

Set the correct local paths for your dataset.

In [5]:
import os

# Adjust base_dir to wherever the dataset lives on your machine
base_dir = r"C:\Users\Tejas\OneDrive\Desktop\MedIntuit\mediscan\data\chest_xray"
train_path = os.path.join(base_dir, "train")
val_path = os.path.join(base_dir, "val")
print("Train dir:", train_path)
print("Val dir:", val_path)

Train dir: C:\Users\Tejas\OneDrive\Desktop\MedIntuit\mediscan\data\chest_xray\train
Val dir: C:\Users\Tejas\OneDrive\Desktop\MedIntuit\mediscan\data\chest_xray\val
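
Hard-coded absolute paths only work on one machine. Below is a minimal sketch of a more portable setup. It assumes the dataset root contains `train` and `val` subfolders as above, and the helper name `resolve_split_dirs` is hypothetical, not part of the notebook.

```python
from pathlib import Path


def resolve_split_dirs(base_dir, splits=("train", "val")):
    """Return {split: Path} for each split folder, failing early if one is missing.

    Hypothetical helper for illustration; not part of the original notebook.
    """
    base = Path(base_dir).expanduser()
    dirs = {split: base / split for split in splits}
    missing = [str(p) for p in dirs.values() if not p.is_dir()]
    if missing:
        raise FileNotFoundError(f"Missing dataset folders: {missing}")
    return dirs
```

It could be used as `dirs = resolve_split_dirs(base_dir)` followed by `train_path = str(dirs["train"])`, so a wrong path fails here rather than later inside `load_dataset`.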


## 3. Load Dataset

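Load the train and validation splits using Hugging Face Datasets. Each folder is loaded on its own, and `imagefolder` places everything it finds under a single split named `train`, which is why `split='train'` also appears in the validation call. Class labels are inferred from the subfolder names.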

In [7]:
from datasets import load_dataset

train_dataset = load_dataset('imagefolder', data_dir=train_path, split='train')
val_dataset = load_dataset('imagefolder', data_dir=val_path, split='train')

print("Sample train_dataset[0]:", train_dataset[0])
print("train_dataset features:", train_dataset.features)

Sample train_dataset[0]: {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=L size=2090x1858 at 0x148C9476DD0>, 'label': 0}
train_dataset features: {'image': Image(mode=None, decode=True, id=None), 'label': ClassLabel(names=['NORMAL', 'PNEUMONIA'], id=None)}


## 4. Prepare Labels

Get class names and label mappings.

In [8]:
labels = train_dataset.features['label'].names
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}

print("Labels:", labels)

Labels: ['NORMAL', 'PNEUMONIA']


## 5. Ensure All Images Are RGB

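ViT expects three-channel input, but the X-rays are stored as single-channel grayscale (`mode=L` in the sample above), so convert every image to RGB before preprocessing.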

In [9]:
from PIL import Image
import numpy as np

def ensure_rgb(img):
    if isinstance(img, Image.Image):
        if img.mode != "RGB":
            return img.convert("RGB")
        return img
    if isinstance(img, np.ndarray):
        if img.ndim == 2:  # grayscale
            return Image.fromarray(img).convert("RGB")
        elif img.ndim == 3 and img.shape[2] == 1:
            return Image.fromarray(img.squeeze(-1)).convert("RGB")
        elif img.ndim == 3 and img.shape[2] == 3:
            return Image.fromarray(img)
    return img

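def fix_dataset_rgb(dataset):
    # Assigning to dataset[i]['image'] only changes a temporary dict, so the
    # conversion has to go through map(), which returns a new dataset.
    # Note: this decodes and re-encodes every image, so it can be slow.
    return dataset.map(lambda example: {"image": ensure_rgb(example["image"])})

train_dataset = fix_dataset_rgb(train_dataset)
val_dataset = fix_dataset_rgb(val_dataset)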

## 6. Load Pretrained ViT and Freeze Base

Load the ViT model and processor, and freeze the base model for head-only training.

In [10]:
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch

model_name = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(model_name)

# Load the backbone with a fresh classification head sized for our labels.
# ignore_mismatched_sizes=True discards the checkpoint's 1000-class ImageNet head.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Freeze base ViT parameters so only model.classifier is trained
for param in model.vit.parameters():
    param.requires_grad = False




Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
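
To see how small the trainable part is under feature extraction, the head's parameter count can be worked out by hand. This sketch assumes the ViT-Base hidden size of 768 (it can be checked with `model.config.hidden_size`) and uses a hypothetical helper, `linear_param_count`.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of parameters in a fully connected layer (weights plus optional bias)."""
    return in_features * out_features + (out_features if bias else 0)


hidden_size = 768  # assumed: hidden size of ViT-Base
num_classes = 2    # NORMAL, PNEUMONIA

head_params = linear_param_count(hidden_size, num_classes)
print(f"Trainable head parameters: {head_params}")  # 768 * 2 + 2 = 1538
```

On the real model, `sum(p.numel() for p in model.parameters() if p.requires_grad)` should report the same figure once the backbone is frozen.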


## 7. Preprocessing Function

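Define a transform that preprocesses images for ViT. `set_transform` applies it lazily, batch by batch, each time examples are accessed, so nothing is preprocessed up front.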

In [11]:
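def transform(example_batch):
    # ensure_rgb is a no-op for images that are already RGB
    images = [ensure_rgb(img) for img in example_batch['image']]
    inputs = processor(images=images, return_tensors='pt')
    inputs['label'] = example_batch['label']
    return inputs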

train_dataset.set_transform(transform)
val_dataset.set_transform(transform)
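
For intuition about what the processor produces, here is a sketch of the per-pixel arithmetic. It assumes this checkpoint's processor rescales by 1/255 and normalizes with mean 0.5 and standard deviation 0.5 per channel (check `processor.image_mean` and `processor.image_std` to confirm). The function name `normalize_pixel` is illustrative only, and resizing to 224x224 is not shown.

```python
def normalize_pixel(value, mean=0.5, std=0.5, rescale=1 / 255):
    """Map an 8-bit pixel value to the model's input range."""
    return (value * rescale - mean) / std


for v in (0, 128, 255):
    print(v, "->", round(normalize_pixel(v), 4))
```

Under these assumptions, black (0) maps to -1 and white (255) maps to 1, so the model sees inputs centered around zero.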

## 8. Training Arguments

Set up Hugging Face Trainer arguments.

In [12]:
# Install accelerate (required by the Trainer) and evaluate.
# The version specifier is quoted so the shell does not treat ">" as a redirect.
%pip install "accelerate>=0.26.0" transformers datasets evaluate

Note: you may need to restart the kernel to use updated packages.


In [13]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    # Keep the raw 'image' column; the on-the-fly transform needs it
    remove_unused_columns=False,
)

In [15]:
import os

# Sanity check: rebuild train_path from base_dir and count the images per class
train_path = os.path.join(base_dir, "train")
print("Train classes:", os.listdir(train_path))

for cls in os.listdir(train_path):
    cls_path = os.path.join(train_path, cls)
    print(f"Found {len(os.listdir(cls_path))} images in {cls}")


Train classes: ['NORMAL', 'PNEUMONIA']
Found 1341 images in NORMAL
Found 3875 images in PNEUMONIA
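
The counts above show a roughly 3:1 imbalance toward PNEUMONIA, which is worth keeping in mind when reading accuracy numbers. As a sketch (not part of the original notebook), the majority-class baseline and inverse-frequency class weights can be computed from the counts. The weighting formula `total / (num_classes * count)` is one common convention, assumed here for illustration.

```python
counts = {"NORMAL": 1341, "PNEUMONIA": 3875}  # from the cell output above

total = sum(counts.values())
num_classes = len(counts)

# Accuracy of a model that always predicts the most frequent class
baseline_acc = max(counts.values()) / total

# Inverse-frequency weights: rarer classes get larger weights
class_weights = {name: total / (num_classes * n) for name, n in counts.items()}

print(f"Majority-class baseline accuracy: {baseline_acc:.3f}")
for name, w in class_weights.items():
    print(f"{name}: weight {w:.3f}")
```

Such weights could be passed to a weighted cross-entropy loss. The Trainer's default loss is unweighted, so a model trained as shown here should be judged against the baseline rather than against 50%.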


## 9. Train the Model

Use Hugging Face Trainer to train only the classification head.

In [40]:
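from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    # Keep the raw 'image' column; the on-the-fly transform needs it
    remove_unused_columns=False,
)

# The frozen backbone receives no gradients, so only the head is updated
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
trainer.train()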
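
With evaluation running every epoch, the Trainer reports only the loss unless it is given a metric function through its `compute_metrics` argument. Below is a minimal sketch. It assumes NumPy is available and that class index 0 is NORMAL and 1 is PNEUMONIA, as in the label list above. The function body is illustrative, not the author's code.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy and per-class recall from Trainer evaluation output."""
    logits, labels = eval_pred  # EvalPrediction unpacks to (predictions, label_ids)
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    metrics = {"accuracy": float((preds == labels).mean())}
    for idx, name in enumerate(["NORMAL", "PNEUMONIA"]):
        mask = labels == idx
        if mask.any():
            metrics[f"recall_{name}"] = float((preds[mask] == idx).mean())
    return metrics
```

It would be passed as `compute_metrics=compute_metrics` when constructing the `Trainer`. Per-class recall is included because, with imbalanced classes, accuracy alone can hide poor performance on NORMAL images.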

## 10. Save Model

Save the fine-tuned model and processor locally.

In [41]:
model_path = "./fine_tuned_model"
model.save_pretrained(model_path)
processor.save_pretrained(model_path)

['./fine_tuned_model\\preprocessor_config.json']