In [3]:
import shutil
import os, shutil
from skimage.exposure import match_histograms
import numpy as np

from sklearn.model_selection import train_test_split

%run ./variables.ipynb
%run ./utils.ipynb
%run ../utils/data_utils.ipynb
%run ../utils/image_utils.ipynb

# Dataset-building knobs:
#   N_LIM - minimum number of images a taxon needs to be kept (smaller taxa are eliminated).
#   norm  - when True, every image is histogram-matched to the reference image before saving.
N_LIM, norm = 10, False

In [4]:
# Load the grayscale reference image used for histogram matching, and keep a copy
# next to the saved models so the same reference can be reused later.
# TODO(review): absolute, user-specific path; consider moving it to variables.ipynb.
REF_IMG_PATH = "/mnt/nvme-storage/pfauregi/datasets/atlas/ref_img.png"

ref = cv2.imread(REF_IMG_PATH, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if ref is None:
    raise FileNotFoundError(f"Could not read reference image: {REF_IMG_PATH}")

ref_copy_path = os.path.join(SAVED_MODELS_ROOT, "ref_img.png")
# cv2.imwrite reports failure (e.g. missing directory) via its return value only.
if not cv2.imwrite(ref_copy_path, ref):
    raise OSError(f"Could not write reference image to: {ref_copy_path}")

# Fetching files: collect every PNG image of each selected taxon, across all atlases.
# taxons_dict maps taxon -> list of {"source": path of the input image,
#                                    "target": "<taxon>/<file>" path relative to a split root}.
taxons_dict = {}
selected_taxons = get_selected_taxons(SELECTED_TAXONS)
for atlas_dir in ATLAS_PATH:
    print(atlas_dir)
    # Sorted listings make the file order (and therefore the seeded train/test split
    # further down) independent of the filesystem's directory ordering.
    for taxon in sorted(os.listdir(atlas_dir)):
        if taxon not in selected_taxons.keys():
            continue
        dir_path = os.path.join(atlas_dir, taxon)
        if not os.path.isdir(dir_path):
            continue  # stray file at atlas level, not a taxon folder
        for file in sorted(os.listdir(dir_path)):
            source_file = os.path.join(dir_path, file)
            # Check the real extension: the former `file.split(".")[1] == "png"` test
            # rejected names like "img.v2.png" and accepted names like "img.png.bak".
            is_png = os.path.splitext(file)[1].lower() == ".png"
            if is_png and os.path.isfile(source_file):
                taxons_dict.setdefault(taxon, []).append(
                    {"source": source_file, "target": os.path.join(taxon, file)}
                )

# Filtering: keep only taxa with at least N_LIM images.
# X holds the file records and y the matching taxon label (index-aligned).
X, y = [], []
eliminated_taxons = {}  # used as an insertion-ordered set; only the keys matter
for taxon, taxon_files in taxons_dict.items():
    if len(taxon_files) >= N_LIM:
        X.extend(taxon_files)
        y.extend([taxon] * len(taxon_files))
    else:
        eliminated_taxons.setdefault(taxon, None)

assert X, f"No taxon has at least N_LIM={N_LIM} images; nothing to split."

print(len(X), "images detected belonging to", len(set(y)), "classes found in", len(ATLAS_PATH), "atlas!")
print("Eliminated taxon (insufficient number of images):", list(eliminated_taxons))

# Train/test split, stratified on the taxon label so class proportions are preserved
# in both subsets (stratification needs >= 2 images per class, which N_LIM ensures).
TEST_SIZE = 0.2
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=SPLIT_SEED, stratify=y
)
# NOTE(review): these two dicts are not used in this section; kept in case later cells rely on them.
taxons_dict_train = {}
taxons_dict_test = {}

print("Train dataset composed of", len(X_train), "images and", len(np.unique(y_train)), "classes.")
print("Test dataset composed of", len(X_test), "images and", len(np.unique(y_test)), "classes.")

# Building dataset: clear DATASET_PATH, then write every train/test image under
# <split root>/<taxon>/<file>, optionally histogram-matched to `ref`, and passed
# through convert_to_square(new_size=256) (presumably a 256x256 square; confirm in image_utils).
delete_all_files_in_folder(DATASET_PATH)

# Distinct loop names keep the module-level X / y (full dataset) intact; the old
# loop rebound them, leaving X / y silently equal to the test split after this cell.
splits = [
    (TRAIN_DATASET_PATH, X_train),
    (TEST_DATASET_PATH, X_test),
]
for split_idx, (split_root, split_records) in enumerate(splits, start=1):
    print(split_idx, "/", len(splits))
    for record in split_records:
        source_file = record["source"]
        target_file = os.path.join(split_root, record["target"])
        check_dirs(target_file)  # presumably creates the target's parent folder

        img = cv2.imread(source_file, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread signals unreadable/corrupt files by returning None
            raise OSError(f"Could not read image: {source_file}")
        if norm:
            # `multichannel` was deprecated in scikit-image 0.19 and later removed;
            # a 2-D grayscale image needs no channel argument in any version.
            img = match_histograms(img, ref).astype("uint8")
        img = convert_to_square(img, new_size=256)

        if not cv2.imwrite(target_file, img):
            raise OSError(f"Could not write image: {target_file}")
print("Finished !")

/mnt/nvme-storage/pfauregi/datasets/atlas/BRG
/mnt/nvme-storage/pfauregi/datasets/atlas/IDF
/mnt/nvme-storage/pfauregi/datasets/atlas/RA
9661 images detected belonging to 187 classes found in 3 atlas!
Eliminated taxon (unsufficient number of images): dict_keys(['PLEV', 'NMIC', 'SPIN', 'GYAC', 'BPAX', 'TLEV', 'NVIR', 'NAAN', 'NRHY', 'PSCA', 'NVDA', 'NESC'])
Train dataset composed of 7728 images and 187 classes.
Train dataset composed of 1933 images and 187 classes.
1 / 2
2 / 2
Finished !


In [None]:
# Export a per-taxon image count table: one row per taxon found on disk (including
# taxa later eliminated by N_LIM), with the total and a count per atlas.
import csv
from collections import Counter

COUNTS_CSV_PATH = "./test.csv"  # NOTE(review): generic name; consider e.g. taxon_counts.csv


def _atlas_name(source_path):
    """Return the atlas folder name of a '<atlas>/<taxon>/<file>' source path."""
    return os.path.basename(os.path.dirname(os.path.dirname(source_path)))


# normpath strips a trailing separator, which would otherwise give an empty column name.
labels = ["taxon", "total"] + [os.path.basename(os.path.normpath(atlas)) for atlas in ATLAS_PATH]

dict_array = []
for taxon in sorted(taxons_dict):
    records = taxons_dict[taxon]
    per_atlas = Counter(_atlas_name(record["source"]) for record in records)
    dict_array.append({"taxon": taxon, "total": len(records), **per_atlas})

# newline="" is required by the csv module to avoid spurious blank rows on Windows.
with open(COUNTS_CSV_PATH, "w", newline="") as f:
    # restval=0 fills atlases with no image for a taxon; extra keys are ignored.
    writer = csv.DictWriter(f, fieldnames=labels, restval=0, extrasaction="ignore")
    writer.writeheader()
    writer.writerows(dict_array)