In [1]:
import pandas as pd
import numpy as np
import os
import random
from PIL import Image
from collections import Counter
from functools import reduce
import copy
from sklearn.neighbors import NearestNeighbors

In [2]:
def load_dataset(csv_path, image_dir):
    df = pd.read_csv(csv_path)
    images = []
    possible_extensions = ['.tif', '.png']
    
    for _, row in df.iterrows():
        image_path = None
        for ext in possible_extensions:
            temp_path = os.path.join(image_dir, f"{row['ID']}{ext}")
            if os.path.exists(temp_path):
                image_path = temp_path
                break
        if image_path is None:
            raise FileNotFoundError(f"No image found for ID {row['ID']} with extensions {possible_extensions}")
        
        image = Image.open(image_path).convert('RGB')
        
        image = image.resize((384,384))  # resize the image to 384x384
        
        image = np.array(image)
        labels = row.drop('ID').to_numpy()
        images.append({'features': image, 'labels': labels, 'id': row['ID']})
    
    return images, df.columns[1:]


csv_path = '../data/fundus/MuReD/train_data.csv'
image_dir = '../data/fundus/MuReD/images/images/'
dataset, labels = load_dataset(csv_path, image_dir)

In [3]:
def clone_sample(sample, new_id):
    clone = {'features': sample['features'].copy(), 'labels': sample['labels'].copy(), 'id': new_id}
    return clone

def calculate_ir_lbl(D, label_index):
    label_counts = np.sum([instance['labels'] for instance in D], axis=0)
    label_count = label_counts[label_index]
    max_label_count = np.max(label_counts)
    return max_label_count / label_count if label_count != 0 else float('inf')

def calculate_scu_mble_ins(instance, ir_lbl):
    instance_labels = np.where(instance['labels'] == 1)[0]
    k = len(instance_labels)
    if k == 0:
        return 0
    product_ir_lbl = reduce(lambda x, y: x * y, (ir_lbl[label] for label in instance_labels), 1)
    mean_ir_lbl = np.mean([ir_lbl[label] for label in instance_labels]) if instance_labels.size > 0 else 1
    scumble_ins = 1 - (1 / mean_ir_lbl) * (product_ir_lbl ** (1 / k))
    return scumble_ins

def calculate_scu_mble(D, ir_lbl):
    scumble_ins = [calculate_scu_mble_ins(instance, ir_lbl) for instance in D]
    scumble = np.mean(scumble_ins)
    return scumble, scumble_ins

In [4]:
def nearest_neighbour(X):
    nbs=NearestNeighbors(n_neighbors=3,metric='euclidean',algorithm='kd_tree').fit(X)
    _,indices= nbs.kneighbors(X)
    return indices

def smote(ref, nn_indices, imbalance_images):
    neighbour = random.choice(nn_indices[ref,1:])
    ratio = random.random()
    gap = imbalance_images[ref,:] - imbalance_images[neighbour,:]
    new_x = np.array(imbalance_images[ref,:] + ratio * gap)
    return new_x.reshape(384,384,3).astype('uint8')

In [5]:
def my_remedial(ori_D, P, threshold=0.5):
    D = copy.deepcopy(ori_D)
    samples_to_clone = int(len(D) * P)
    # print(f"cloning {samples_to_clone} samples")
    
    # Calculate imbalance levels
    ir_lbl = [calculate_ir_lbl(D, i) for i in range(len(labels))]
    ir_mean = np.mean(ir_lbl)
    print(f"ir_lbl: {ir_lbl}")
    print(f"ir_mean: {ir_mean}")
    # Calculate SCUMBLE
    scumble, scumble_ins = calculate_scu_mble(D, ir_lbl)
    print(f"scumble: {scumble}")
    
    imbalance_index = []
    # imbalance_images = []
    for i in range(len(D)):
        if scumble_ins[i] > scumble:
            imbalance_index.append(i)
            # imbalance_images.append(D[i]['features'].flatten())
    
    imbalance_images = np.array(imbalance_images)
    # indices = nearest_neighbour(imbalance_images)
    
    print(f"imbalance index: {len(imbalance_index)}")
    new_id_counter = 0
    new_instances = []
    new_labels_counter = Counter()
    while samples_to_clone > 0:
        samples = random.sample(imbalance_index, min(samples_to_clone, len(imbalance_index)))
        for sample_index in samples:
            # smote
            # ref = imbalance_index.index(sample_index)
            # new_x = smote(ref, indices, imbalance_images)
            
            instance = D[sample_index]
            clone_instance = copy.deepcopy(instance)
            clone_instance['id'] = f"DA_{new_id_counter}"
            # clone_instance['features'] = new_x
            
            # Add minority labels
            clone_instance['labels'][[label for label in range(len(labels)) if ir_lbl[label] <= ir_mean]] = 0
            
            new_instances.append(clone_instance)
            new_labels_counter.update(np.where(clone_instance['labels'] == 1)[0])
            new_id_counter += 1
            samples_to_clone -= 1
            if samples_to_clone <= 0:
                break
            
    D.extend(new_instances)
    print(f"added {new_id_counter+1} samples")
    print(new_labels_counter)
    return D


preprocessed_dataset = my_remedial(dataset, P=0.4)#, threshold=0.2)

ir_lbl: [1.0, 1.0025316455696203, 2.933333333333333, 1.8767772511848342, 3.168, 3.142857142857143, 3.046153846153846, 5.577464788732394, 6.285714285714286, 7.92, 9.0, 8.25, 8.425531914893616, 8.608695652173912, 10.702702702702704, 13.655172413793103, 14.142857142857142, 15.23076923076923, 16.5, 1.894736842105263]
ir_mean: 7.1181649096420205
scumble: 0.028608447515591443
imbalance index: 92
added 706 samples
Counter({14: 245, 9: 161, 16: 77, 17: 61, 18: 51, 11: 40, 13: 30, 10: 24, 15: 22, 12: 8})


In [6]:
def save_preprocessed_dataset(D, labels, image_dir, output_csv_path):
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    label_data = []
    image_names = []
    for instance in D:
        if 'DA_' in instance['id']:
            image_path = os.path.join(image_dir, f"{instance['id']}.png")
            image = Image.fromarray(instance['features'])
            image.save(image_path)
        else:
            possible_extensions = ['.tif', '.png']
            for ext in possible_extensions:
                temp_path = os.path.join(image_dir, f"{instance['id']}{ext}")
                if os.path.exists(temp_path):
                    image_path = temp_path
                    break
            
        label_data.append(instance['labels'])
        image_names.append(instance['id'])

    label_df = pd.DataFrame(label_data, columns=labels)
    label_df.insert(0, 'ID', image_names)
    label_df.to_csv(os.path.join(output_csv_path, 'my_remedial_smote384_train_data.csv'), index=False)
    
    counts = label_df.sum(axis=0)
    counts.to_dict()
    print(counts)


output_dir = '../data/fundus/MuReD/images/my_remedial_smote384/'
output_csv_path = '../data/fundus/MuReD/'
save_preprocessed_dataset(preprocessed_dataset, labels, output_dir, output_csv_path)

ID        aria_c_25_1aria_c_7_2aria_c_38_2aria_c_2_8aria...
DR                                                      396
NORMAL                                                  395
MH                                                      135
ODC                                                     211
TSLN                                                    125
ARMD                                                    126
DN                                                      130
MYA                                                      71
BRVO                                                     63
ODP                                                     211
CRVO                                                     68
CNV                                                      88
RS                                                       55
ODE                                                      76
LS                                                      282
CSR                                     