In [None]:
# !pip install imbalanced-learn

In [None]:
# !pip install tqdm

In [11]:
import os
import pandas as pd
import torchvision.transforms as transforms
from PIL import Image
import cv2
from tqdm import tqdm
# from imblearn.over_sampling import RandomOverSampler, SMOTE

In [12]:
# INITIAL_IMAGE_DIRECTORY = "one_eye_images_copy"
# DESTINATION_DIRECTORY = "one_eye_images_copy"

def augment_and_save_image(image_name, input_path, output_path, transform, n_augments=5):
    """
    _summary_

    Args:
        image_name (_type_): input image name
        input_path (_type_): input directory path
        output_path (_type_): ouput directory path
        transform (_type_): transforms that will be applied to images
        n_augments (int, optional): augments per image. Defaults to 5.

    Returns:
        list: list od augmented images names
        
    """
    image = Image.open(os.path.join(input_path, image_name), 'r')
    augmented_image_names = []
    
    for i in range(n_augments):
        aug_image = transform(image)
        aug_image_name = image_name.split('.')[0] + f"_aug{i}.jpg"
        aug_image_path = os.path.join(output_path, aug_image_name)
        aug_image.save(aug_image_path)
        augmented_image_names.append(aug_image_name)
    
    return augmented_image_names

In [32]:
def augment_dataset(dataset, disease_columns, input_path, image_output_path, include_meta=True, transform=None):
    """
    _summary_

    Args:
        dataset (_type_): dataset with images and labels
        disease_columns (_type_): columns with values 1 or 0 for disease presence
        input_path (_type_): input image file path
        image_output_path (_type_): output path for augmented images
        include_meta (bool, optional): include parameters such as age and sex or not. Defaults to True.
        transform (_type_, optional): transformations to the image. Defaults to None.

    Returns:
        _type_: _description_
    """
    if not transform:
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.2, hue=0),
            transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
            ])
    
    if not os.path.exists(image_output_path):
        os.makedirs(image_output_path)
    
    augmented_data = []
    columns = dataset.columns.to_list()
    columns.remove('image_id')
    target_samples = dataset[disease_columns].value_counts().max()
    
    for disease in disease_columns:
        group = dataset[dataset[disease] == 1]
        n_augments = target_samples - len(group)
        # augments_per_image = n_augments // len(group)
        # if augments_per_image == 0:
        #     augments_per_image += 1
        augmented = 0
        progress_bar = tqdm(total=n_augments, desc=f"Augmenting {disease}", initial=0)
        # for _, row in group.iterrows():
        #     if augmented == n_augments:
        #         break
        #     augmented_image_names = augment_and_save_image(row['image_id'], input_path, image_output_path, transform, augments_per_image)
        #     for aug_image_name in augmented_image_names:
        #         if include_meta:
        #             augmented_data.append({
        #                 'image_id': aug_image_name,
        #                 'patient_age': row['patient_age'],
        #                 'patient_sex': row['patient_sex'],
        #                 **{col: row[col] for col in disease_columns}
        #             })
        #         else:
        #             augmented_data.append({
        #                 'image_id': aug_image_name,
        #                 **{col: row[col] for col in disease_columns}
        #             })
        #         augmented += 1
        #         progress_bar.update(1)
        while augmented < n_augments:
            for _, row in group.iterrows():
                if augmented == n_augments:
                    break
                augmented_image_names = augment_and_save_image(row['image_id'], input_path, image_output_path, transform, 1)
                for aug_image_name in augmented_image_names:
                    augmented_data.append({
                        'image_id': aug_image_name,
                        **{col: row[col] for col in columns}
                    })
                    # print(augmented_data)
                    augmented += 1
                    progress_bar.update(1)
    augmented_dataframe = pd.DataFrame(augmented_data)
    return augmented_dataframe
    

In [28]:
def undersampling_images_only(dataset, classes, sample_number=300):
    """
    _summary_

    Args:
        dataset (_type_): pandas dataset
        classes (_type_): list of classes
        sample_number (int, optional): sample number for each class. Defaults to 300.

    Returns:
        pandas.DataFrame(): dataset with sampled data
    
    """
    result_df = pd.DataFrame()
    for disease_class in classes:
        disease_df = dataset[dataset[disease_class] == 1]
        if len(disease_df) > sample_number:
            disease_df = disease_df.sample(sample_number)
        result_df = pd.concat([result_df, disease_df])
    return result_df

In [29]:
dataset = pd.read_csv("odir_data_only_one_eye.csv")
classes = ['diabetic_retinopathy', 'amd', 'hypertensive_retinopathy', 'normal_eye', 'glaucoma', 'cataract']
undersampled_dataset = undersampling_images_only(dataset, classes)

In [30]:
undersampled_dataset

Unnamed: 0,image_id,patient_id,patient_age,patient_sex,exam_eye,diabetic_retinopathy,amd,hypertensive_retinopathy,normal_eye,data_source,glaucoma,cataract
927,757_left.jpg,9282,58.0,2,2,1,0,0,0,ODIR,0,0
4884,4245_right.jpg,12770,56.0,1,1,1,0,0,0,ODIR,0,0
4920,4266_left.jpg,12791,48.0,1,2,1,0,0,0,ODIR,0,0
5304,4479_left.jpg,13004,55.0,1,2,1,0,0,0,ODIR,0,0
4602,4079_right.jpg,12604,54.0,2,1,1,0,0,0,ODIR,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...
2134,2248_right.jpg,10773,74.0,2,1,0,0,0,0,ODIR,0,1
2135,2251_left.jpg,10776,70.0,1,2,0,0,0,0,ODIR,0,1
2136,2251_right.jpg,10776,70.0,1,1,0,0,0,0,ODIR,0,1
2137,2262_left.jpg,10787,65.0,2,2,0,0,0,0,ODIR,0,1


In [33]:
augmentation_dataset = augment_dataset(undersampled_dataset, classes, 'one_eye_images_ODIR_only', 'undersampled_and_augmented', include_meta=False)

Augmenting diabetic_retinopathy: 0it [00:00, ?it/s]

Augmenting diabetic_retinopathy: 0it [00:00, ?it/s]




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

Augmenting amd: 100%|██████████| 60/60 [00:38<00:00,  1.58it/s]:00<?, ?it/s]
Augmenting hypertensive_retinopathy: 100%|██████████| 190/190 [02:02<00:00,  1.36it/s]

Augmenting hypertensive_retinopathy: 100%|██████████| 190/190 [02:02<00:00,  1.55it/s]
Augmenting normal_eye: 0it [00:00, ?it/s]0 [00:00<?, ?it/s]
Augmenting glaucoma: 100%|██████████| 40/40 [00:30<00:00,  1.43it/s]

Augmenting glaucoma: 100%|██████████| 40/40 [00:30<00:00,  1.33it/s]


[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

Augmenting catar

In [34]:
augmentation_dataset

Unnamed: 0,image_id,patient_id,patient_age,patient_sex,exam_eye,diabetic_retinopathy,amd,hypertensive_retinopathy,normal_eye,data_source,glaucoma,cataract
0,43_left_aug0.jpg,8568,35.0,1,2,0,1,0,0,ODIR,0,0
1,48_left_aug0.jpg,8573,69.0,2,2,0,1,0,0,ODIR,0,0
2,48_right_aug0.jpg,8573,69.0,2,1,0,1,0,0,ODIR,0,0
3,53_left_aug0.jpg,8578,65.0,2,2,0,1,0,0,ODIR,0,0
4,53_right_aug0.jpg,8578,65.0,2,1,0,1,0,0,ODIR,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...
293,112_right_aug0.jpg,8637,57.0,2,1,0,0,0,0,ODIR,0,1
294,119_left_aug0.jpg,8644,59.0,1,2,0,0,0,0,ODIR,0,1
295,188_right_aug0.jpg,8713,55.0,2,1,0,0,0,0,ODIR,0,1
296,218_right_aug0.jpg,8743,36.0,2,1,0,0,0,0,ODIR,0,1


In [35]:
result = pd.concat([undersampled_dataset, augmentation_dataset])
result[classes].value_counts()

diabetic_retinopathy  amd  hypertensive_retinopathy  normal_eye  glaucoma  cataract
0                     0    0                         0           0         1           300
                                                                 1         0           300
                                                     1           0         0           300
                           1                         0           0         0           300
                      1    0                         0           0         0           300
1                     0    0                         0           0         0           300
Name: count, dtype: int64

In [37]:
result.to_csv("augmented_and_original_df_ODIR.csv")