# Creating a train, validation and test split with even distribution of scans with features of interesting

### Split split focuses on the following features

- 4: sub retinal fluid
- 5: sub retainal hyper reflective material
- 7: fibrovascular PED
- 8: drusen
- 9: poster hyloid membrane detachment
- 10: choroid
- 13: fibrosis

In [9]:
import os
import numpy as np
import pandas as pd
import sklearn
import cv2
from tqdm import tqdm

FEATURES_OI = [4, 5, 7, 8, 9, 10, 13]

WORK_SPACE = "/home/olle/PycharmProjects/LODE/workspace"
LBL_PATH = "feature_segmentation/segmentation/data/train_data/hq_examples/masks"

MASK_DIR = os.path.join(WORK_SPACE, LBL_PATH)

#### load all labels

In [10]:
label_names = os.listdir(MASK_DIR)

labels = {}
for label_name in label_names:
    label_path = os.path.join(MASK_DIR, label_name)
    label = cv2.imread(label_path)
    labels[label_name] = label

#### create dict with record feature distribution

In [22]:
feature_record_dist = {4: [], 5: [], 7: [], 8: [], 9: [], 10: [], 13: []}
for feature in FEATURES_OI:
    print("Loogging feature: ", feature)
    for label_name in tqdm(labels.keys()):
        if feature in labels[label_name]:
            feature_record_dist[feature].append(label_name)

100%|██████████| 525/525 [00:00<00:00, 8258.18it/s]
100%|██████████| 525/525 [00:00<00:00, 10323.53it/s]
100%|██████████| 525/525 [00:00<00:00, 10834.85it/s]
  0%|          | 0/525 [00:00<?, ?it/s]

Loogging feature:  4
Loogging feature:  5
Loogging feature:  7
Loogging feature:  8


100%|██████████| 525/525 [00:00<00:00, 10156.40it/s]
100%|██████████| 525/525 [00:00<00:00, 10042.37it/s]
100%|██████████| 525/525 [00:00<00:00, 10474.29it/s]
100%|██████████| 525/525 [00:00<00:00, 10586.48it/s]

Loogging feature:  9
Loogging feature:  10
Loogging feature:  13





In [27]:
for feature in FEATURES_OI:
    print(f"Number of samples in feature {feature} is {len(feature_record_dist[feature])}")

Number of samples in feature 4 is 98
Number of samples in feature 5 is 60
Number of samples in feature 7 is 176
Number of samples in feature 8 is 114
Number of samples in feature 9 is 126
Number of samples in feature 10 is 356
Number of samples in feature 13 is 80


In [62]:
test_records = []
validation_records = []
record_names = list(labels.keys())

In [63]:
n_srf_test = 0
n_srhm_test = 0
n_drusen_test = 0
n_fibrosis_test = 0
n_phmd_test = 0 
n_not_choroid_test = 0
n_irf_test = 0
n_em_test = 0 
n_fvped_test = 0

test_records = []
print("Number of available records before test data:", len(record_names))
for label_name in record_names:
    lbl = labels[label_name]
    
    if n_fvped_test < 5:
        if 7 in lbl:
            print(7)
            n_fvped_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_em_test < 5:
        if 1 in lbl:
            print(1)
            n_em_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_srf_test < 5:
        if 4 in lbl:
            print(4)
            n_srf_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_srhm_test < 5:
        if 5 in lbl:
            print(5)
            n_srhm_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_drusen_test < 5:
        if 8 in lbl:
            print(8)
            n_drusen_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue

    if n_fibrosis_test < 5:
        if 13 in lbl:
            print(13)
            n_fibrosis_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
            
    if n_phmd_test < 5:
        if 9 in lbl:
            print(9)
            n_phmd_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_not_choroid_test < 5:
        if 10 not in lbl:
            print("not", 10)
            n_not_choroid_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_irf_test < 5:
        if 3 in lbl:
            print(3)
            n_irf_test += 1
            test_records.append(label_name)
            record_names.remove(label_name)
            continue
            
print("Number of available records after test data:", len(record_names))

Number of available records before test data: 525
13
7
8
1
1
8
7
7
1
1
1
9
7
7
8
not 10
4
8
4
8
9
9
9
3
13
9
4
not 10
3
5
not 10
3
not 10
4
4
not 10
13
13
3
5
3
13
5
5
5
Number of available records after test data: 480


In [64]:
n_srf_validation = 0
n_srhm_validation = 0
n_drusen_validation = 0
n_fibrosis_validation = 0
n_phmd_validation = 0 
n_not_choroid_validation = 0
n_irf_validation = 0 
n_em_validation = 0
n_fvped_validation = 0 

validation_records = []
print("Number of available records before test data:", len(record_names))
for label_name in record_names:
    lbl = labels[label_name]
    
    if n_fvped_validation < 5:
        if 7 in lbl:
            print(7)
            n_fvped_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_srf_validation < 5:
        if 4 in lbl:
            print(4)
            n_srf_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_srhm_validation < 5:
        if 5 in lbl:
            print(5)
            n_srhm_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_drusen_validation < 5:
        if 8 in lbl:
            print(8)
            n_drusen_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue

    if n_fibrosis_validation < 5:
        if 13 in lbl:
            print(13)
            n_fibrosis_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
                
    if n_phmd_validation < 5:
        if 9 in lbl:
            print(9)
            n_phmd_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_not_choroid_validation < 5:
        if 10 not in lbl:
            print("not", 10)
            n_not_choroid_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
            
    if n_irf_validation < 5:
        if 3 in lbl:
            print(3)
            n_irf_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
    
    if n_em_validation < 5:
        if 1 in lbl:
            print(1)
            n_em_validation += 1
            validation_records.append(label_name)
            record_names.remove(label_name)
            continue
            
            
print("Number of available records after test data:", len(record_names))

Number of available records before test data: 480
7
13
not 10
7
not 10
9
13
8
7
8
13
13
7
7
8
not 10
1
8
1
not 10
8
not 10
1
3
1
4
9
3
13
4
1
4
9
3
4
4
3
5
3
9
9
5
5
5
5
Number of available records after test data: 435


## verify data split distribution

In [65]:
print(f"number of train records: {len(record_names)}, \
      number validation records: {len(validation_records)}, \
      number of test records: {len(validation_records)}")

number of train records: 435,       number validation records: 45,       number of test records: 45


In [69]:
# test so no records overlap
print([record for record in record_names if record in validation_records])
print([record for record in record_names if record in test_records])
print([record for record in validation_records if record in test_records])

[]
[]
[]


In [72]:
save_dir = "feature_segmentation/segmentation/data/train_data/data_split"

test_save_path = os.path.join(WORK_SPACE, save_dir, "test_ids.csv")
train_save_path = os.path.join(WORK_SPACE, save_dir, "train_ids.csv")
validation_save_path = os.path.join(WORK_SPACE, save_dir, "validation_ids.csv")

pd.DataFrame(test_records).to_csv(test_save_path)
pd.DataFrame(validation_records).to_csv(validation_save_path)
pd.DataFrame(record_names).to_csv(train_save_path)

In [75]:
#### visualize data split

import shutil

overview_path = "feature_segmentation/segmentation/data/train_data/hq_examples/overview"

os.makedirs(os.path.join(WORK_SPACE, save_dir, "train_overview"), exist_ok=True)
os.makedirs(os.path.join(WORK_SPACE, save_dir, "test_overview"), exist_ok=True)
os.makedirs(os.path.join(WORK_SPACE, save_dir, "validation_overview"), exist_ok=True)

for record in record_names:
    shutil.copy(os.path.join(WORK_SPACE, overview_path, record), 
               os.path.join(WORK_SPACE, save_dir, "train_overview"))
    

for record in test_records:
    shutil.copy(os.path.join(WORK_SPACE, overview_path, record), 
               os.path.join(WORK_SPACE, save_dir, "test_overview"))
    

for record in validation_records:
    shutil.copy(os.path.join(WORK_SPACE, overview_path, record), 
               os.path.join(WORK_SPACE, save_dir, "validation_overview"))
    