In [1]:
import pandas as pd
import numpy as np
import os
from PIL import Image
from collections import Counter
from functools import reduce
import copy

In [2]:
def load_dataset(csv_path, image_dir):
    df = pd.read_csv(csv_path)
    images = []
    possible_extensions = ['.tif', '.png']
    
    for _, row in df.iterrows():
        image_path = None
        for ext in possible_extensions:
            temp_path = os.path.join(image_dir, f"{row['ID']}{ext}")
            if os.path.exists(temp_path):
                image_path = temp_path
                break
        if image_path is None:
            raise FileNotFoundError(f"No image found for ID {row['ID']} with extensions {possible_extensions}")
        
        image = Image.open(image_path)
        image = np.array(image)
        labels = row.drop('ID').to_numpy()
        images.append({'features': image, 'labels': labels, 'ID': row['ID']})
    
    return images, df.columns[1:]

def calculate_ir_lbl(D, label_index):
    label_counts = np.sum([instance['labels'] for instance in D], axis=0)
    label_count = label_counts[label_index]
    max_label_count = np.max(label_counts)
    return max_label_count / label_count if label_count != 0 else float('inf')

def calculate_scu_mble_ins(instance, ir_lbl):
    instance_labels = np.where(instance['labels'] == 1)[0]
    k = len(instance_labels)
    if k == 0:
        return 0
    product_ir_lbl = reduce(lambda x, y: x * y, (ir_lbl[label] for label in instance_labels), 1)
    mean_ir_lbl = np.mean([ir_lbl[label] for label in instance_labels]) if instance_labels.size > 0 else 1
    scumble_ins = 1 - (1 / mean_ir_lbl) * (product_ir_lbl ** (1 / k))
    return scumble_ins

def calculate_scu_mble(D, ir_lbl):
    scumble_ins = [calculate_scu_mble_ins(instance, ir_lbl) for instance in D]
    scumble = np.mean(scumble_ins)
    return scumble, scumble_ins

def remedial(D, labels):
    # Calculate imbalance levels
    ir_lbl = [calculate_ir_lbl(D, i) for i in range(len(labels))]
    ir_mean = np.mean(ir_lbl)
    
    # Calculate SCUMBLE
    scumble, scumble_ins = calculate_scu_mble(D, ir_lbl)
    
    new_id_counter = 0
    new_instances = []
    for i in range(len(D)):
        if scumble_ins[i] > scumble:
            instance = D[i]
            clone_instance = copy.deepcopy(instance)
            clone_instance['ID'] = f"DA_{new_id_counter}"
            
            # Maintain minority labels
            instance['labels'][[label for label in range(len(labels)) if ir_lbl[label] <= ir_mean]] = 0
            # Maintain majority labels
            clone_instance['labels'][[label for label in range(len(labels)) if ir_lbl[label] > ir_mean]] = 0
            
            new_instances.append(clone_instance)
            new_id_counter += 1
    
    D.extend(new_instances)
    print(f"added {new_id_counter+1} samples")
    return D

csv_path = '../data/fundus/MuReD/train_data.csv'
image_dir = '../data/fundus/MuReD/images/images/'
dataset, labels = load_dataset(csv_path, image_dir)
preprocessed_dataset = remedial(dataset, labels)

added 294 samples


In [4]:
def save_preprocessed_dataset(D, labels, image_dir, output_csv_path):
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    label_data = []
    image_names = []
    for instance in D:
        if 'DA_' in instance['ID']:
            image_path = os.path.join(image_dir, f"{instance['ID']}.png")
            image = Image.fromarray(instance['features'])
            image.save(image_path)
        else:
            possible_extensions = ['.tif', '.png']
            for ext in possible_extensions:
                temp_path = os.path.join(image_dir, f"{instance['ID']}{ext}")
                if os.path.exists(temp_path):
                    image_path = temp_path
                    break
            
        label_data.append(instance['labels'])
        image_names.append(instance['ID'])

    label_df = pd.DataFrame(label_data, columns=labels)
    label_df.insert(0, 'ID', image_names)
    label_df.to_csv(os.path.join(output_csv_path, 'remedial_train_data.csv'), index=False)
    
    counts = label_df.sum(axis=0)
    counts.to_dict()
    print(counts)


output_dir = '../data/fundus/MuReD/images/remedial/'
output_csv_path = '../data/fundus/MuReD/'
save_preprocessed_dataset(preprocessed_dataset, labels, output_dir, output_csv_path)

ID        aria_c_25_1aria_c_7_2aria_c_38_2aria_c_2_8aria...
DR                                                      396
NORMAL                                                  395
MH                                                      135
ODC                                                     211
TSLN                                                    125
ARMD                                                    126
DN                                                      130
MYA                                                      71
BRVO                                                     63
ODP                                                      50
CRVO                                                     44
CNV                                                      48
RS                                                       47
ODE                                                      46
LS                                                       37
CSR                                     