In [1]:
import pandas as pd
import numpy as np
import os
import random
from PIL import Image
from collections import Counter

In [2]:
def load_dataset(csv_path, image_dir, remedial_image_dir=None):
    df = pd.read_csv(csv_path)
    images = []
    possible_extensions = ['.tif', '.png']
    
    for _, row in df.iterrows():
        image_path = None
        if row['ID'].startswith("DA_") and remedial_image_dir is not None:
            image_path = os.path.join(remedial_image_dir, f"{row['ID']}.png")
        else:
            for ext in possible_extensions:
                temp_path = os.path.join(image_dir, f"{row['ID']}{ext}")
                if os.path.exists(temp_path):
                    image_path = temp_path
                    break
            if image_path is None:
                raise FileNotFoundError(f"No image found for ID {row['ID']} with extensions {possible_extensions}")
        
        image = Image.open(image_path).convert('RGB')
        image = np.array(image)
        labels = row.drop('ID').to_numpy()
        images.append({'features': image, 'labels': labels, 'id': row['ID']})
    
    return images, df.columns[1:]

def calculate_ir_per_label(D, label_index):
    label_counts = np.sum([instance['labels'] for instance in D], axis=0)
    label_count = label_counts[label_index]
    max_label_count = np.max(label_counts)
    return max_label_count / label_count if label_count != 0 else float('inf')

def calculate_mean_ir(D, labels):
    irs = [calculate_ir_per_label(D, i) for i in range(len(labels))]
    return np.mean(irs)

def clone_sample(sample, new_id):
    clone = {'features': sample['features'].copy(), 'labels': sample['labels'].copy(), 'id': new_id}
    return clone

def ml_ros(D, P):
    print(f"original dataset size: {len(D)}")
    samples_to_clone = int(len(D) * P / 100)
    labels = range(len(D[0]['labels']))
    mean_ir = calculate_mean_ir(D, labels)
    
    minority_bags = {label: [] for label in labels}
    for label in labels:
        ir_label = calculate_ir_per_label(D, label)
        if ir_label > mean_ir:
            for instance in D:
                if instance['labels'][label] == 1:
                    minority_bags[label].append(instance)
    
    new_id_counter = 0
    while samples_to_clone > 0:
        if not any(minority_bags.values()):
            break
        for label, bag in list(minority_bags.items()):
            if not bag:
                continue
            sample_index = random.randint(0, len(bag) - 1)
            sample = bag[sample_index]
            clone = clone_sample(sample, f'DA_DA_{new_id_counter}')
            D.append(clone)
            new_id_counter += 1
            samples_to_clone -= 1
            if calculate_ir_per_label(D, label) <= mean_ir:
                minority_bags.pop(label)
            if samples_to_clone <= 0:
                break
    
    print(f"added {new_id_counter+1} samples")
    return D


csv_path = '../data/fundus/MuReD/remedial_train_data.csv'
image_dir = '../data/fundus/MuReD/images/images/'
remedial_image_dir = '../data/fundus/MuReD/images/remedial/'
dataset, labels = load_dataset(csv_path, image_dir, remedial_image_dir=remedial_image_dir)
preprocessed_dataset = ml_ros(dataset, P=40)

original dataset size: 2057
added 162 samples


In [3]:
def save_preprocessed_dataset(D, labels, image_dir, output_csv_path):
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    label_data = []
    image_names = []
    for instance in D:
        if 'DA_' in instance['id']:
            image_path = os.path.join(image_dir, f"{instance['id']}.png")
            image = Image.fromarray(instance['features'])
            image.save(image_path)
        else:
            possible_extensions = ['.tif', '.png']
            for ext in possible_extensions:
                temp_path = os.path.join(image_dir, f"{instance['id']}{ext}")
                if os.path.exists(temp_path):
                    image_path = temp_path
                    break
            
        label_data.append(instance['labels'])
        image_names.append(instance['id'])

    label_df = pd.DataFrame(label_data, columns=labels)
    label_df.insert(0, 'ID', image_names)
    label_df.to_csv(os.path.join(output_csv_path, 'mlros_remedial_train_data.csv'), index=False)
    
    counts = label_df.sum(axis=0)
    counts.to_dict()
    print(counts)
    
    
output_dir = '../data/fundus/MuReD/images/mlros_remedial/'
output_csv_path = '../data/fundus/MuReD/'
save_preprocessed_dataset(preprocessed_dataset, labels, output_dir, output_csv_path)

ID        aria_c_25_1aria_c_7_2aria_c_38_2aria_c_2_8aria...
DR                                                      396
NORMAL                                                  395
MH                                                      135
ODC                                                     211
TSLN                                                    125
ARMD                                                    126
DN                                                      130
MYA                                                      71
BRVO                                                     63
ODP                                                      62
CRVO                                                     58
CNV                                                      61
RS                                                       58
ODE                                                      56
LS                                                       56
CSR                                     