In [1]:
import albumentations as A
import cv2
import os
from pathlib import Path
import pandas as pd

In [2]:
# Load the MuReD training labels. The 'ID' column is forced to str so that it
# can later be used to build image file names; every column after the first
# is summed to get the number of positive rows per class.
train_csv = '../data/fundus/MuReD/train_data.csv'
labels_df = pd.read_csv(train_csv).astype({'ID': str})

class_counts = labels_df.iloc[:, 1:].sum(axis=0)
class_counts

DR        396
NORMAL    395
MH        135
ODC       211
TSLN      125
ARMD      126
DN        130
MYA        71
BRVO       63
ODP        50
CRVO       44
CNV        48
RS         47
ODE        46
LS         37
CSR        29
HTR        28
ASR        26
CRS        24
OTHER     209
dtype: int64

In [3]:
# Augmentation pipeline. An earlier variant applied Rotate(limit=45) and
# HorizontalFlip independently, each with p=0.5; the current one wraps the
# candidates in a single OneOf that runs with p=1.
candidate_transforms = [
    # Rotation; p=1.0 means it is always applied once selected.
    A.Rotate(limit=45, p=1.0),
    # Horizontal flip; always applied once selected.
    A.HorizontalFlip(p=1.0),
    # Random brightness / contrast adjustment.
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
    # Hue / saturation / value adjustment.
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0),
]

transform = A.Compose([A.OneOf(candidate_transforms, p=1)])

In [4]:
# Random oversampling (ROS): duplicate original training images, together with
# their label rows, until every class has at least `class_add_count` positive
# rows in `labels_df`. Each duplicate is saved as DA_<n>.png in `da_images_path`
# so that every row added to `labels_df` refers to a file that really exists.
images_path = Path("../data/fundus/MuReD/images/images")  # training set images path
da_images_path = Path("../data/fundus/MuReD/images/ros")  # duplicated (oversampled) images path

class_add_count = 200  # minimum number of positive rows wanted for every class
RANDOM_SEED = 42       # fixes the DA_<n> -> source image mapping across re-runs
ID_name = 1            # running number used to name the duplicates DA_1, DA_2, ...

da_images_path.mkdir(parents=True, exist_ok=True)

unreadable_ids = set()  # source IDs whose image could not be loaded; never drawn again
pass_number = 0
while True:
    class_counts = labels_df.iloc[:, 1:].sum(axis=0)
    short_classes = class_counts[class_counts < class_add_count]
    if short_classes.empty:
        break

    # Handle only the first under-represented class per pass: the rows are
    # multi-label, so duplicating them also raises other class counts and the
    # totals must be recomputed before looking at the next class.
    class_name = short_classes.index[0]
    missing = int(class_add_count - short_classes.iloc[0])

    # Draw from readable original images only, so no draw is wasted on DA_ copies.
    usable = ~labels_df['ID'].str.startswith('DA') & ~labels_df['ID'].isin(unreadable_ids)
    candidates = labels_df[(labels_df[class_name] == 1) & usable]
    if candidates.empty:
        # Without this guard the loop would spin forever making no progress.
        raise ValueError(
            f"Cannot oversample class {class_name!r}: no readable original image is labelled with it."
        )

    # A different but deterministic seed per pass keeps the run reproducible
    # without drawing the very same rows on every pass.
    sampled = candidates.sample(
        n=min(missing, len(candidates)),
        random_state=RANDOM_SEED + pass_number,
    )
    pass_number += 1

    kept_index, new_ids = [], []
    for index, img_name in sampled['ID'].items():
        png_path = images_path / f"{img_name}.png"
        image_path = png_path if png_path.exists() else images_path / f"{img_name}.tif"
        image = cv2.imread(str(image_path))
        if image is None:
            # Missing or undecodable file: remember it so that it is not drawn again.
            unreadable_ids.add(img_name)
            continue

        new_id = f"DA_{ID_name}"
        new_img_path = da_images_path / f"{new_id}.png"
        # imread and imwrite both use BGR, so a plain copy needs no colour conversion.
        if not cv2.imwrite(str(new_img_path), image):
            raise IOError(f"Failed to write {new_img_path}")

        ID_name += 1
        kept_index.append(index)
        new_ids.append(new_id)

    if kept_index:
        # Append the whole pass at once; this keeps the column dtypes and avoids
        # the private, row-by-row DataFrame._append.
        duplicates = sampled.loc[kept_index].assign(ID=new_ids)
        labels_df = pd.concat([labels_df, duplicates], ignore_index=True)

class_counts

DR        607
NORMAL    395
MH        258
ODC       377
TSLN      261
ARMD      385
DN        270
MYA       208
BRVO      213
ODP       231
CRVO      216
CNV       245
RS        209
ODE       200
LS        200
CSR       200
HTR       225
ASR       200
CRS       200
OTHER     375
dtype: int64

In [5]:
# Save the oversampled label table next to the original training CSV,
# without writing the DataFrame index as a column.
ros_csv_path = Path('../data/fundus/MuReD') / 'ros_train_data.csv'
labels_df.to_csv(ros_csv_path, index=False)