**Medical Image Classification with MedNIST Dataset**

End-to-end training and evaluation on the MedNIST dataset:
- Create the dataset and use MONAI transforms to preprocess the image data.
- Use MONAI's DenseNet121 for classification.
- Train the model with PyTorch and evaluate it on the test dataset.

In [None]:
!pip install monai
!pip install monai-weekly
!pip install ignite
!python -c "import monai; print(monai.__version__)"
!python -c "import monai" || pip install -q "monai-weekly[nibabel, ignite, pillow, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

Successfully installed monai-1.5.0


In [None]:
import os, shutil, tempfile, PIL, logging, sys
import matplotlib.pyplot as plt
import torch
import numpy as np
from sklearn.metrics import classification_report
from monai.apps import download_and_extract
from monai.data import decollate_batch, DataLoader
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121, densenet121
from monai.config import print_config
from monai.transforms import ( Activations, EnsureChannelFirst, AsDiscrete, Compose,
    LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity,)
from monai.utils import set_determinism
from monai.engines import SupervisedTrainer
from monai.handlers import StatsHandler
from monai.inferers import SimpleInferer
from monai.networks import eval_mode
print_config()

# Root directory for the downloaded dataset and saved checkpoints.
# Override it with the MONAI_DATA_DIRECTORY environment variable; otherwise use ./monai
# under the current working directory (on Colab the working directory is typically
# /content, giving /content/monai). makedirs(exist_ok=True) keeps the cell re-runnable,
# unlike a shell `mkdir`, which errors once the folder exists.
rootdir = os.environ.get('MONAI_DATA_DIRECTORY') or os.path.join(os.getcwd(), 'monai')
os.makedirs(rootdir, exist_ok=True)
rootdir

In [None]:
# MedNIST: 64x64 medical images grouped into one sub-folder per class.
# Fetch the archive, check it against the published MD5, and unpack it under `rootdir`;
# the class folders end up in `datadir`.
resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
datadir = os.path.join(rootdir, "MedNIST")
compressed_file = os.path.join(rootdir, "MedNIST.tar.gz")
download_and_extract(
    url=resource,
    filepath=compressed_file,
    output_dir=rootdir,
    hash_val=md5,
)

MedNIST.tar.gz: 59.0MB [00:03, 18.4MB/s]                            

2025-06-28 14:21:05,571 - INFO - Downloaded: /content/monai/MedNIST.tar.gz





2025-06-28 14:21:05,678 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2025-06-28 14:21:05,679 - INFO - Writing into directory: /content/monai.


In [None]:
# Fix the global random seed through MONAI so that the shuffled data split and
# the training run are repeatable (reproducibility).
SEED = 0
set_determinism(seed=SEED)

In [None]:
# Optional cleanup (shell command, uncomment to run): the MedNIST folder can contain a
# README.md next to the class sub-folders, which a plain directory listing would pick up.
#!rm -rf monai/MedNIST/README.md

In [None]:
# Read image paths from the class folders.
# Only sub-directories are classes: the MedNIST folder may also hold stray files
# (e.g. README.md) that would otherwise be counted as an extra class.
# File lists are sorted so the seeded shuffle later yields the same split on every
# filesystem (os.listdir order is arbitrary).
class_names = sorted(
    name for name in os.listdir(datadir)
    if os.path.isdir(os.path.join(datadir, name))
)
num_classes = len(class_names)
imgfiles = [
    [os.path.join(datadir, name, fname)
     for fname in sorted(os.listdir(os.path.join(datadir, name)))]
    for name in class_names
]
class_names, num_classes

(['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT'], 6)

In [None]:
# `import PIL` alone does not load the PIL.Image submodule; import it explicitly.
from PIL import Image

# Flatten the per-class file lists into one list of paths and a parallel list of
# integer labels (label = index of the class folder in `class_names`).
num_each_folder = [len(files) for files in imgfiles]
imgfiles_lst = []
image_class = []
for label, files in enumerate(imgfiles):
    imgfiles_lst.extend(files)
    image_class.extend([label] * len(files))
num_total = len(image_class)

# Image size is read from the first file only (assumes all images share it);
# the context manager closes the file handle.
with Image.open(imgfiles_lst[0]) as first_image:
    image_width, image_height = first_image.size

print(f'total image count: {num_total}')
print(f'image dimensions: {image_width}, {image_height}')
print(f'label names: {class_names}')
print(f'label counts: {num_each_folder}')

total image count: 58954
labe; dimensions: 64, 64
label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
label counts: [10000, 8954, 10000, 10000, 10000, 10000]


In [None]:
# Training / validation / test lists: shuffle one permutation of all indices, then cut
# it into a test slice, a validation slice, and the remainder for training.
val_percent, test_percent = .1, .1
n_items = len(imgfiles_lst)
perm = np.arange(n_items)
np.random.shuffle(perm)
n_test = int(test_percent * n_items)
val_end = n_test + int(val_percent * n_items)
test_indices = perm[:n_test]
val_indices = perm[n_test:val_end]
train_indices = perm[val_end:]


def _gather(values, idx):
    """Return the elements of ``values`` at the positions listed in ``idx``."""
    return [values[k] for k in idx]


train_x, train_y = _gather(imgfiles_lst, train_indices), _gather(image_class, train_indices)
val_x, val_y = _gather(imgfiles_lst, val_indices), _gather(image_class, val_indices)
test_x, test_y = _gather(imgfiles_lst, test_indices), _gather(image_class, test_indices)

In [None]:
# Transforms: every split loads the file, moves channels first and rescales intensity;
# training additionally applies a random rotation (up to pi/12) and a random flip
# along spatial axis 0, each with probability 0.5.
def _preprocess_steps():
    """Fresh instances of the deterministic load / channel-first / scale steps."""
    return [LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()]


train_transforms = Compose(_preprocess_steps() + [
    RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
    RandFlip(spatial_axis=0, prob=0.5),
])

val_transforms = Compose(_preprocess_steps())

# Post-processing for the ROC-AUC metric: softmax over the network outputs and
# one-hot encoding of the integer labels.
y_pred_trans = Compose([Activations(softmax=True)])
y_trans = Compose([AsDiscrete(to_onehot=num_classes)])

In [None]:
class MedDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing image file paths with their labels.

    Item ``index`` is ``(transforms(image_files[index]), labels[index])``; the
    transform runs lazily, each time an item is requested.
    """

    def __init__(self, image_files, labels, transforms):
        # `image_files` and `labels` are expected to be index-aligned sequences.
        self.image_files, self.labels, self.transforms = image_files, labels, transforms

    def __len__(self):
        # The dataset size is defined by the list of files.
        return len(self.image_files)

    def __getitem__(self, index):
        sample = self.transforms(self.image_files[index])
        return sample, self.labels[index]

# Data loaders. Only the training loader shuffles: evaluation metrics do not depend on
# sample order, and a fixed order keeps collected predictions aligned with
# val_x / test_x, so mistakes can be traced back to file paths.
BATCH_SIZE = 300
NUM_WORKERS = 0  # increase (e.g. 4) to decode images in parallel worker processes

train_ds = MedDataset(train_x, train_y, train_transforms)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

val_ds = MedDataset(val_x, val_y, val_transforms)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

test_ds = MedDataset(test_x, test_y, val_transforms)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# DenseNet121 network, loss, optimizer and training schedule.
# `device` is always a torch.device (CUDA when available, else CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_classes)
loss_function = torch.nn.CrossEntropyLoss()
learning_rate = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
epochs = 5
val_interval = 1  # validate every `val_interval` epochs
auc = ROCAUCMetric()  # accumulates predictions/labels per call; aggregate() then reset()

In [None]:
# Model training: one pass over `train_loader` per epoch, then (every `val_interval`
# epochs) validation ROC-AUC and accuracy. The weights with the best AUC are saved.
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
best_model_path = os.path.join(rootdir, 'best_metric_model.pth')

# Put the model on the selected device; each batch is moved to the same device below.
model.to(device)

for epoch in range(epochs):
    model.train()
    epochloss = 0.0
    step = 0  # stays 0 if the loader is empty
    for step, (inputs, labels) in enumerate(train_loader, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epochloss += loss.item()
    epochloss /= max(step, 1)
    epoch_loss_values.append(epochloss)
    print(f'epoch {epoch + 1} avg loss {epochloss:.4f}')

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            y_pred = torch.tensor([], dtype=torch.float32, device=device)
            y = torch.tensor([], dtype=torch.long, device=device)
            for val_images, val_labels in val_loader:
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                y = torch.cat([y, val_labels], dim=0)
            # ROC-AUC is fed per-sample softmax probabilities and one-hot labels.
            y_onehot = [y_trans(i) for i in decollate_batch(y, detach=False)]
            y_pred_act = [y_pred_trans(i) for i in decollate_batch(y_pred)]
            auc(y_pred_act, y_onehot)
            result = auc.aggregate()
            auc.reset()
            del y_pred_act, y_onehot
            metric_values.append(result)
            acc_value = torch.eq(y_pred.argmax(dim=1), y)
            acc_metric = acc_value.sum().item() / len(acc_value)
            if result > best_metric:
                best_metric = result
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), best_model_path)
                print('saved new best metric model')
            print(f'epoch {epoch + 1} val AUC {result:.4f} val accuracy {acc_metric:.4f} '
                  f'best AUC {best_metric:.4f} at epoch {best_metric_epoch}')

print(f'train completed, best metric: {best_metric:.4f} at epoch {best_metric_epoch}')

In [None]:
# Model evaluation on the test split with the best checkpoint saved during training.
# map_location lets a checkpoint written on GPU load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors/containers (a plain state dict).
state_dict = torch.load(os.path.join(rootdir, 'best_metric_model.pth'),
                        map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    # Each batch is [images, labels]; the labels are the SECOND element.
    for test_images, test_labels in test_loader:
        pred = model(test_images.to(device)).argmax(dim=1)
        y_true.extend(test_labels.tolist())
        y_pred.extend(pred.tolist())

# Passing `labels` keeps one row per class (aligned with target_names) even if some
# class is absent from y_true and y_pred; print() renders the report's line breaks.
print(classification_report(y_true, y_pred, labels=list(range(num_classes)),
                            target_names=class_names, digits=4))

**MedNIST with DenseNet-121 and the MONAI SupervisedTrainer workflow**

In [None]:
from monai.transforms import EnsureChannelFirstD, LoadImageD, ScaleIntensityD

# Dictionary-based preprocessing: only the "image" entry of each sample is transformed
# (load the file, move channels first, rescale intensity).
_image_key = "image"
transform = Compose(
    [
        LoadImageD(keys=_image_key, image_only=True),
        EnsureChannelFirstD(keys=_image_key),
        ScaleIntensityD(keys=_image_key),
    ]
)

In [None]:
from monai.apps import MedNISTDataset

# Training section of MedNIST via MONAI's dataset helper, stored under `rootdir`;
# download=True lets it fetch the data if needed. Each item goes through `transform`.
dataset = MedNISTDataset(
    root_dir=rootdir,
    section='training',
    transform=transform,
    download=True,
)

In [None]:
# Network and SupervisedTrainer: DenseNet121 trained with Adam + cross-entropy, with
# training statistics reported by StatsHandler through the stdout logging set up here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_epochs = 5
model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(device)  # 6 MedNIST classes
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

trainer = SupervisedTrainer(
    # Use the detected device (same one the model is on); a hard-coded 'cuda:0'
    # fails on CPU-only runtimes.
    device=device,
    max_epochs=max_epochs,
    train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True),
    network=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
    loss_function=torch.nn.CrossEntropyLoss(),
    inferer=SimpleInferer(),
    train_handlers=[StatsHandler()],
)

In [None]:
# Run the SupervisedTrainer configured above for `max_epochs` epochs; StatsHandler
# output is emitted through the stdout logging configuration.
trainer.run()

In [None]:
from pathlib import Path

# Class names = sorted sub-directory names of the MedNIST folder. `is_dir()` must be
# CALLED: a bare `x.is_dir` is a bound method and always truthy, so files such as
# README.md would be listed as a class. The sorted order presumably matches the label
# indices MedNISTDataset assigns (the sample shown below, label 0 -> AbdomenCT, is
# consistent with that) — verify if the dataset layout changes.
dataset_dir = Path(rootdir, 'MedNIST')
class_names = sorted(x.name for x in dataset_dir.iterdir() if x.is_dir())
testdata = MedNISTDataset(root_dir=rootdir, transform=transform,
                          section="test", download=False, runtime_cache=True)
class_names, next(iter(testdata))

(['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT', 'README.md'],
 {'image': metatensor([[[0.1250, 0.1250, 0.1250,  ..., 0.1193, 0.1250, 0.1250],
           [0.1250, 0.1250, 0.1250,  ..., 0.1136, 0.1250, 0.1307],
           [0.1250, 0.1250, 0.1250,  ..., 0.1080, 0.1193, 0.1364],
           ...,
           [0.1250, 0.1250, 0.1250,  ..., 0.1875, 0.1250, 0.1136],
           [0.1250, 0.1250, 0.1250,  ..., 0.1477, 0.1250, 0.1250],
           [0.1250, 0.1250, 0.1250,  ..., 0.1193, 0.1250, 0.1307]]]),
  'label': 0,
  'class_name': 'AbdomenCT'})

In [None]:
# Show predicted vs. ground-truth class for the first `max_items` test samples.
max_items = 10
with eval_mode(model):
    loader = DataLoader(testdata, batch_size=1, num_workers=0)
    for item in loader:
        logits = model(item['image'].to(device)).detach().to('cpu')
        scores = np.array(logits)[0]  # batch_size=1: take the single row
        predicted, truth = class_names[scores.argmax()], item['class_name'][0]
        print(f'class prediction is {predicted}. ground-truth: {truth}')
        max_items -= 1
        if not max_items:
            break

class prediction is AbdomenCT. ground-truth: AbdomenCT
class prediction is BreastMRI. ground-truth: BreastMRI
class prediction is ChestCT. ground-truth: ChestCT
class prediction is CXR. ground-truth: CXR
class prediction is Hand. ground-truth: Hand
class prediction is HeadCT. ground-truth: HeadCT
class prediction is HeadCT. ground-truth: HeadCT
class prediction is CXR. ground-truth: CXR
class prediction is ChestCT. ground-truth: ChestCT
class prediction is BreastMRI. ground-truth: BreastMRI
