# Apply SMOTE algorithm to the training set to fix class imbalance issues
Note: Run 'Dataset Construction.ipynb' and 'Format Dicom Datasets.ipynb' before proceeding

Note: This code requires approximately 24GiB of ram to complete according to htop

In [1]:
import os
import imblearn
import tensorflow.keras as keras
import pandas as pd
import PIL
import cv2
from PIL import Image, ImageOps
from collections import Counter
from random import shuffle
import numpy as np

In [2]:
train_dir = './dataset/train'
meta_path = os.path.join(train_dir, 'metadata.csv')

os.makedirs(os.path.join(train_dir,'sampled','NORMAL'), exist_ok=True)
os.makedirs(os.path.join(train_dir,'sampled','PNEUMONIA'), exist_ok=True)
os.makedirs(os.path.join(train_dir,'sampled','COVID-19'), exist_ok=True)

In [3]:
target_samples = 5000
batch_size = 500
idx = 0

csv = pd.read_csv(meta_path)
rus = imblearn.under_sampling.RandomUnderSampler(random_state=42, sampling_strategy={'NORMAL': target_samples, 'PNEUMONIA': target_samples})
smote = imblearn.over_sampling.SMOTE(random_state=42, k_neighbors=5)

images = {'NORMAL':[], 'PNEUMONIA':[], 'COVID-19':[]}
labels = {'NORMAL':[], 'PNEUMONIA':[], 'COVID-19':[]}
smoted_images = {'NORMAL':[], 'PNEUMONIA':[], 'COVID-19':[]}

for index, series in csv.iterrows():
    # Load all images into Dictionary
    images[series['finding']].append(cv2.imread(series['imagename'], cv2.IMREAD_GRAYSCALE).flatten())
    labels[series['finding']].append(series['finding'])
    
shuffle(images['NORMAL'])
shuffle(images['PNEUMONIA'])
shuffle(images['COVID-19'])

# Preemptively undersample the normal class to the desired 5000 images
print(len(images['NORMAL']), len(images['PNEUMONIA']))
undersample, undersample_labels = rus.fit_resample(images['NORMAL'] + images['PNEUMONIA'], labels['NORMAL'] + labels['PNEUMONIA'])

images['NORMAL'] = [sample for sample, label in zip(undersample, undersample_labels) if label == 'NORMAL']
images['PNEUMONIA'] = [sample for sample, label in zip(undersample, undersample_labels) if label == 'PNEUMONIA']

total_samples = len(images['NORMAL']) + len(images['PNEUMONIA']) + len(images['COVID-19'])
class_imbalance = [len(images['NORMAL']), len(images['PNEUMONIA']), len(images['COVID-19'])]
class_imbalance[:] = [sample / total_samples for sample in class_imbalance]

# SMOTE per batch
while True:
    print(idx * batch_size)
    print('NORMAL: {}, PNEUMONIA: {}, COVID-19: {}'.format(int((idx + 1) * batch_size * class_imbalance[0]), int((idx + 1) * batch_size * class_imbalance[1]) , int((idx + 1) * batch_size * class_imbalance[2])))
    batch, batch_labels = smote.fit_resample(
        images['NORMAL'][int(idx * batch_size * class_imbalance[0]):
                         int((idx + 1) * batch_size * class_imbalance[0]) - 1] +
        images['PNEUMONIA'][int(idx * batch_size * class_imbalance[1]):
                         int((idx + 1) * batch_size * class_imbalance[1]) - 1] +
        images['COVID-19'][int(idx * batch_size * class_imbalance[1]):
                         int((idx + 1) * batch_size * class_imbalance[1]) - 1],
        labels['NORMAL'][int(idx * batch_size * class_imbalance[0]):
                         int((idx + 1) * batch_size * class_imbalance[0]) - 1] +
        labels['PNEUMONIA'][int(idx * batch_size * class_imbalance[1]):
                         int((idx + 1) * batch_size * class_imbalance[1]) - 1] +
        labels['COVID-19'][int(idx * batch_size * class_imbalance[1]):
                         int((idx + 1) * batch_size * class_imbalance[1]) - 1])
    for image, label in zip(batch, batch_labels):
        smoted_images[label].append(image)
    
    if idx * batch_size < len(images['NORMAL']):
        break;
        
    idx += 1
    
batch, batch_labels = smote.fit_resample(
    images['NORMAL'][int(idx * batch_size * class_imbalance[0]):len(images['NORMAL']) - 1] +
    images['PNEUMONIA'][int(idx * batch_size * class_imbalance[1]):len(images['PNEUMONIA']) - 1] +
    images['COVID-19'][int(idx * batch_size * class_imbalance[2]):len(images['COVID-19']) - 1],
    labels['NORMAL'][int(idx * batch_size * class_imbalance[0]):len(images['NORMAL']) - 1] +
    labels['PNEUMONIA'][int(idx * batch_size * class_imbalance[1]):len(images['PNEUMONIA']) - 1] +
    labels['COVID-19'][int(idx * batch_size * class_imbalance[2]):len(images['COVID-19']) - 1])
    
for image, label in zip(batch, batch_labels):
    smoted_images[label].append(image)

for label in ['NORMAL', 'PNEUMONIA', 'COVID-19']:
    idx = 0
    for image in smoted_images[label]:
        image = np.reshape(image,(300, 300)).astype(np.uint8)
        cv2.imwrite(os.path.join(train_dir, 'sampled', label, '{}.png'.format(idx)), image)
        idx += 1

7146 5226
0
NORMAL: 211, PNEUMONIA: 211, COVID-19: 77
