In [6]:
# !pip install imbalanced-learn

Collecting imbalanced-learn
  Obtaining dependency information for imbalanced-learn from https://files.pythonhosted.org/packages/9d/41/721fec82606242a2072ee909086ff918dfad7d0199a9dfd4928df9c72494/imbalanced_learn-0.13.0-py3-none-any.whl.metadata
  Downloading imbalanced_learn-0.13.0-py3-none-any.whl.metadata (8.8 kB)
Collecting sklearn-compat<1,>=0.1 (from imbalanced-learn)
  Obtaining dependency information for sklearn-compat<1,>=0.1 from https://files.pythonhosted.org/packages/f0/a8/ad69cf130fbd017660cdd64abbef3f28135d9e2e15fe3002e03c5be0ca38/sklearn_compat-0.1.3-py3-none-any.whl.metadata
  Downloading sklearn_compat-0.1.3-py3-none-any.whl.metadata (18 kB)
Downloading imbalanced_learn-0.13.0-py3-none-any.whl (238 kB)
   ---------------------------------------- 0.0/238.4 kB ? eta -:--:--
   ----- ---------------------------------- 30.7/238.4 kB 1.4 MB/s eta 0:00:01
   ---------- ---------------------------- 61.4/238.4 kB 825.8 kB/s eta 0:00:01
   -------------------- -----------------


[notice] A new release of pip is available: 23.2.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [38]:
# !pip install tqdm

Collecting tqdm
  Obtaining dependency information for tqdm from https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl.metadata
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
     ---------------------------------------- 0.0/57.7 kB ? eta -:--:--
     ------- -------------------------------- 10.2/57.7 kB ? eta -:--:--
     -------------------------- ----------- 41.0/57.7 kB 487.6 kB/s eta 0:00:01
     -------------------------------------- 57.7/57.7 kB 431.2 kB/s eta 0:00:00
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
   ---------------------------------------- 0.0/78.5 kB ? eta -:--:--
   ------------------------------- -------- 61.4/78.5 kB 1.7 MB/s eta 0:00:01
   ---------------------------------------- 78.5/78.5 kB 1.5 MB/s eta 0:00:00
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1



[notice] A new release of pip is available: 23.2.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
import os
import warnings

import numpy as np
import pandas as pd
# from imblearn.over_sampling import RandomOverSampler, SMOTE

In [3]:
import torchvision.transforms as transforms
from PIL import Image
import cv2
from tqdm import tqdm

# INITIAL_IMAGE_DIRECTORY = "one_eye_images_copy"
# DESTINATION_DIRECTORY = "one_eye_images_copy"

def augment_image(image_name, input_path, output_path, transform, n_augments=5):
    """Write ``n_augments`` randomly transformed copies of one image to disk.

    Parameters
    ----------
    image_name : str
        File name of the source image, relative to ``input_path``.
    input_path : str
        Directory containing the source image.
    output_path : str
        Directory the copies are written to (this function does not create it).
    transform : callable
        Called on the opened PIL image once per copy; its result must provide
        ``save(path)`` (e.g. a torchvision transform that returns a PIL image).
    n_augments : int, default 5
        Number of copies to produce; values <= 0 produce none.

    Returns
    -------
    list of str
        File names (not full paths) of the written copies, ``<stem>_aug<i>.jpg``
        for ``i`` in ``range(n_augments)``. Existing files with those names are
        overwritten.
    """
    # splitext keeps dots that belong to the stem ("scan.v2.png" -> "scan.v2");
    # split('.')[0] would truncate to "scan" and collide with other images.
    stem = os.path.splitext(image_name)[0]
    augmented_image_names = []

    # The context manager releases the source file even if a transform/save raises.
    with Image.open(os.path.join(input_path, image_name)) as image:
        for i in range(n_augments):
            aug_image_name = f"{stem}_aug{i}.jpg"
            transform(image).save(os.path.join(output_path, aug_image_name))
            augmented_image_names.append(aug_image_name)

    return augmented_image_names

In [2]:
def _augment_image_numbered(image_name, input_path, output_path, transform, count, start_index):
    """Write ``count`` transformed copies of one image, numbered from ``start_index``.

    Output names are ``<stem>_aug<i>.jpg`` for ``i`` in
    ``range(start_index, start_index + count)``. Choosing ``start_index`` past
    the copies already written for this image prevents later calls from
    overwriting earlier ones. Returns the list of written file names.
    """
    stem = os.path.splitext(image_name)[0]
    names = []
    # The context manager releases the source file even if a transform/save raises.
    with Image.open(os.path.join(input_path, image_name)) as image:
        for i in range(start_index, start_index + count):
            name = f"{stem}_aug{i}.jpg"
            transform(image).save(os.path.join(output_path, name))
            names.append(name)
    return names


def augment_dataset(dataset, target_samples, disease_columns, input_path, image_output_path, df_output_path="", transform=None):
    """Oversample each disease column towards ``target_samples`` rows by image augmentation.

    For every column in ``disease_columns``, the rows where that column equals 1
    form a group, and ``target_samples - len(group)`` new images are generated
    from it. They are spread as evenly as possible: every row in the group gets
    ``n // len(group)`` copies and the first ``n % len(group)`` rows get one
    extra. Each new row copies ``patient_age``, ``patient_sex`` and every
    disease column from its source row.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain ``image_id``, ``patient_age``, ``patient_sex`` and all of
        ``disease_columns``.
    target_samples : int
        Desired number of positive rows per disease.
    disease_columns : sequence of str
        Columns to balance; a row belongs to a disease's group when its value is 1.
    input_path : str
        Directory holding the source images named by ``image_id``.
    image_output_path : str
        Directory for augmented images; created if missing.
    df_output_path : str, default ""
        Directory where ``augmented_df.csv`` is written ("" = current directory).
    transform : callable, optional
        Applied to each PIL image; defaults to random flips, colour jitter and
        Gaussian blur.

    Returns
    -------
    pandas.DataFrame
        One row per augmented image only (source rows are not included), with
        columns ``image_id``, ``patient_age``, ``patient_sex`` and
        ``disease_columns``. The same frame is saved to CSV (index included).

    Notes
    -----
    - Every generated file name is unique. A source image augmented more than
      once (remainder pass, or several diseases for a multi-label image)
      continues its ``_aug<i>`` numbering instead of overwriting earlier copies.
    - Diseases with no positive rows are skipped with a warning. Diseases
      already at or above ``target_samples`` are skipped.
    - All disease columns are copied, so augmenting a multi-label image for one
      disease also raises the counts of its other diseases. Final counts may
      therefore exceed ``target_samples``.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.2, hue=0),
            transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
            ])

    os.makedirs(image_output_path, exist_ok=True)

    columns = ['image_id', 'patient_age', 'patient_sex', *disease_columns]
    augmented_data = []
    # Next free "_aug<i>" index per source image, shared across all diseases.
    next_index = {}

    for disease in disease_columns:
        group = dataset[dataset[disease] == 1]
        if group.empty:
            # The original code divided by len(group) here and crashed.
            warnings.warn(f"No rows with {disease} == 1; skipping augmentation for it.")
            continue
        n_augments = target_samples - len(group)
        if n_augments <= 0:
            continue

        per_image, extra = divmod(n_augments, len(group))
        with tqdm(total=n_augments, desc=f"Augmenting {disease}", initial=0) as progress_bar:
            for position, (_, row) in enumerate(group.iterrows()):
                count = per_image + (1 if position < extra else 0)
                if count == 0:
                    # Only reachable when n_augments < len(group): the remaining rows get nothing.
                    break
                image_id = row['image_id']
                start = next_index.get(image_id, 0)
                names = _augment_image_numbered(image_id, input_path, image_output_path, transform, count, start)
                next_index[image_id] = start + count
                for aug_image_name in names:
                    augmented_data.append({
                        'image_id': aug_image_name,
                        'patient_age': row['patient_age'],
                        'patient_sex': row['patient_sex'],
                        **{col: row[col] for col in disease_columns}
                    })
                progress_bar.update(len(names))

    # Explicit columns keep the schema even when nothing was augmented.
    augmented_dataframe = pd.DataFrame(augmented_data, columns=columns)
    augmented_dataframe.to_csv(os.path.join(df_output_path, "augmented_df.csv"))
    return augmented_dataframe
    