In [None]:
# !pip install medmnist


In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from tqdm import tqdm

import medmnist
from medmnist import INFO, Evaluator

In [2]:
# Print the installed MedMNIST version and the project homepage URL.
print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")


MedMNIST v2.1.0 @ https://github.com/MedMNIST/MedMNIST/


In [3]:
# Key into medmnist.INFO selecting which dataset's metadata is read below
# (the commented-out line is an alternative choice).
data_flag = 'pathmnist'
# data_flag = 'breastmnist'
# Forwarded as `download=` to the dataset constructors in later cells.
download = True

# Run settings. NOTE(review): `lr` is presumably the optimizer learning rate;
# no optimizer is visible in this part of the notebook — confirm.
NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001

# Metadata record for the selected dataset.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
# Number of classes = number of entries in the dataset's label mapping.
n_classes = len(info['label'])

# Resolve the dataset class from the class name stored in the metadata.
DataClass = getattr(medmnist, info['python_class'])

In [4]:
# Preprocessing: convert each image to a tensor, then normalise with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# Directory where the MedMNIST .npz archives are downloaded and reused.
# NOTE(review): machine-specific absolute path — consider making it configurable.
root="/home/uz1/DATA!/medmnist"
# Create the directory together with any missing parents. `exist_ok=True`
# makes this idempotent and avoids the check-then-create race of
# `os.path.exists` + `os.mkdir` (which also failed when the parent was absent).
os.makedirs(root, exist_ok=True)

# # load the data
# train_dataset = DataClass(split='train', transform=data_transform, download=download,root=root)
# test_dataset = DataClass(split='test', transform=data_transform, download=download,root=root)

# pil_dataset = DataClass(split='train', download=download,root=root)

# # encapsulate data into dataloader form
# train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
# test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

In [102]:
# combine multiple datasets into one Class 
class ConcatDataset(medmnist.dataset.MedMNIST2D):
    """Concatenation of several MedMNIST datasets with one shared label space.

    Samples are laid out back to back in the order the datasets were passed:
    indices ``0 .. len(d0) - 1`` come from the first dataset, the next
    ``len(d1)`` indices from the second, and so on.

    Every dataset gets a contiguous block of ``len(dataset.info['label'])``
    global class ids, offset by the class counts of the datasets before it.
    ``class_index[(dataset_position, local_label)]`` gives the global id.

    Multi-label targets (more than one element) are collapsed to a binary
    local label: ``0`` when all entries are zero, ``1`` otherwise. Such a
    dataset still reserves its full block of ids, so part of that block is
    never produced.

    Note:
        The parent initializer is intentionally not called; this object only
        wraps already constructed datasets.
    """

    def __init__(self, *datasets):
        """Store ``datasets`` and build the (dataset, label) -> global id map.

        Args:
            *datasets: Dataset objects that support ``len()``, integer
                indexing returning ``(image, target)``, and expose
                ``info['label']``.
        """
        self.datasets = datasets
        # Map (dataset position, local label) to the class id in the combined
        # dataset. A running offset replaces re-summing all earlier datasets
        # for every single label.
        class_index = {}
        offset = 0
        for dataset_pos, dataset in enumerate(datasets):
            n_labels = len(dataset.info['label'])
            for local_label in range(n_labels):
                class_index[dataset_pos, local_label] = offset + local_label
            offset += n_labels
        self.class_index = class_index

    def __getitem__(self, i):
        """Return ``(image, global_class_id)`` for the combined index ``i``.

        Negative indices count from the end of the combined dataset.
        Single-channel images are repeated to three channels.

        Raises:
            IndexError: If ``i`` lies outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if i < 0:
            # Without this, a negative index would be served by the first
            # dataset instead of addressing the tail of the combined one.
            i += total
        if not 0 <= i < total:
            raise IndexError('index out of range')

        # Walk the datasets, shifting the index until it falls inside one.
        for dataset_pos, dataset in enumerate(self.datasets):
            size = len(dataset)
            if i < size:
                image, target = dataset[i]
                # Flatten so scalar, 0-d and 1-d targets are handled alike.
                target = np.asarray(target).reshape(-1)
                if target.size > 1:
                    # Multi-label target: any positive entry counts as 1.
                    local_label = 0 if target.sum() == 0 else 1
                else:
                    # Index the element explicitly; int() on a 1-element
                    # array is deprecated in recent NumPy releases.
                    local_label = int(target[0])
                # Single-channel image: repeat it to three channels.
                if image.shape[0] == 1:
                    image = image.repeat(3, 1, 1)
                return image, self.class_index[(dataset_pos, local_label)]
            i -= size
        raise IndexError('index out of range')

    def __len__(self):
        """Return the total number of samples across all wrapped datasets."""
        return sum(len(dataset) for dataset in self.datasets)

    def __repr__(self):
        """Return a short summary built only from attributes set in this class."""
        # The parent initializer is never run, so an inherited repr cannot
        # rely on parent-set attributes.
        # NOTE(review): confirm against the installed medmnist version.
        names = ', '.join(type(dataset).__name__ for dataset in self.datasets)
        return f"{type(self).__name__}(n_samples={len(self)}, datasets=[{names}])"

In [103]:
from medmnist.dataset import (
    PathMNIST,
    BreastMNIST,
    OCTMNIST,
    ChestMNIST,
    PneumoniaMNIST,
    DermaMNIST,
    RetinaMNIST,
    BloodMNIST,
)

# Keyword arguments shared by every training-split dataset below.
train_kwargs = dict(split='train', transform=data_transform, download=download, root=root)

# Load the training split of each dataset.
pathmnist = PathMNIST(**train_kwargs)
breastmnist = BreastMNIST(**train_kwargs)
octmnist = OCTMNIST(**train_kwargs)
chestmnist = ChestMNIST(**train_kwargs)
pneumoniamnist = PneumoniaMNIST(**train_kwargs)
dermamnist = DermaMNIST(**train_kwargs)
retinamnist = RetinaMNIST(**train_kwargs)
bloodmnist = BloodMNIST(**train_kwargs)

# Combine the datasets; the argument order determines each dataset's
# position in the combined index space.
combined_dataset = ConcatDataset(
    pathmnist,
    breastmnist,
    octmnist,
    chestmnist,
    pneumoniamnist,
    dermamnist,
    retinamnist,
    bloodmnist,
)

# Print the length of the combined dataset.
print(len(combined_dataset))

Using downloaded and verified file: /home/uz1/DATA!/medmnist/pathmnist.npz
Using downloaded and verified file: /home/uz1/DATA!/medmnist/breastmnist.npz
Using downloaded and verified file: /home/uz1/DATA!/medmnist/octmnist.npz
Using downloaded and verified file: /home/uz1/DATA!/medmnist/chestmnist.npz
Using downloaded and verified file: /home/uz1/DATA!/medmnist/pneumoniamnist.npz
Using downloaded and verified file: /home/uz1/DATA!/medmnist/dermamnist.npz
Using downloaded and verified file: /home/uz1/DATA!/medmnist/retinamnist.npz
Using downloaded and verified file: /home/uz1/DATA!/medmnist/bloodmnist.npz
291241


In [107]:
# Number of label classes declared by each constituent dataset, in order.
class_counts = [len(dataset.info['label']) for dataset in combined_dataset.datasets]

print("number of classes in each dataset")
for count in class_counts:
    print(count)

# Total over all datasets; shown as the cell's output value.
print("total number of classes in combined dataset")
sum(class_counts)



number of classes in each dataset
9
2
4
14
2
7
5
8
total number of classes in combined dataset


51