In [None]:
import pandas as pd
import numpy as np
import os
import re
from PIL import Image
import h5py
from sklearn.model_selection import train_test_split
import argparse

In [None]:
import copy
import logging
import torch
import torch.nn as nn
from torchvision import *
import random
import sys
import time
import cv2

In [None]:
# Path of the CSV that is later loaded into `df_ISIC2018`; the `image` column of
# that table decides which images are written to the HDF5 file.
csv_path = '../../../../../Explainable-NN/Data/train_fold_1.csv'

In [None]:
# Side length (pixels) that every image and mask is resized to before storage.
img_size = 512

# HDF5 file produced by this notebook.
output_path = 'train.h5'
# Dataset sub-directories; a later cell prefixes them with the dataset root.
img_dir =  r'ISIC2018_Task1-2_Training_Input/'
mask_dir = 'ISIC2018_Task2_Training_GroundTruth_v3/'

# Filename patterns. Group 1 is the numeric ISIC id; for masks, group 2 is the
# attribute name. The extension dot is escaped so it only matches a literal '.'
# (an unescaped '.' would also accept e.g. 'ISIC_0000001_jpg').
IMG_FILE_REGEX = r'ISIC_(\d+)\.jpg'
MASK_FILE_REGEX = r'ISIC_(\d+)_attribute_(.*)\.png'

ISIC2019_FILE_REGEX = r'ISIC_(\d+)\.jpg'

# Attribute name (as it appears in mask filenames) -> channel index used for the
# `masks` and `labels` datasets.
ATTR_TO_INDEX = {
    'globules' : 0,
    'milia_like_cyst' : 1,
    'negative_network' : 2,
    'pigment_network' : 3,
    'streaks' : 4,
}

In [None]:
# Root of the local ISIC 2018 Task 1-2 dataset, relative to this notebook.
Back = '../../../../../594A/Dataset/ISIC challenge 2018 Task 1-2/'

# Prefix the dataset sub-directories with the root exactly once. The guard keeps
# this cell idempotent: re-running it must not prepend `Back` a second time.
if not img_dir.startswith(Back):
    img_dir = Back + img_dir
if not mask_dir.startswith(Back):
    mask_dir = Back + mask_dir

#  ISIC2018 train

In [None]:
# Start from an empty table: one column for the image name followed by one
# column per dermoscopic attribute.
df_ISIC2018 = pd.DataFrame(
    columns=[
        'image',
        'globules',
        'milia_like_cyst',
        'negative_network',
        'pigment_network',
        'streaks',
    ]
)

In [None]:
# Preview the table's column layout (it has no rows yet when the cells are run
# top to bottom).
df_ISIC2018.head()

In [None]:
# Build one placeholder row per .jpg found in `img_dir`.
# Every label column starts at 2, a value outside the 0/1 range that the
# (currently disabled) mask-labelling cell would write - presumably marking
# "not labelled yet"; confirm before relying on it.
image_records = []
for img_file in sorted(os.listdir(img_dir)):  # sorted: os.listdir order is arbitrary
    if not img_file.endswith('.jpg'):
        continue
    id_match = re.search(IMG_FILE_REGEX, img_file)
    if id_match is None:
        # A .jpg that does not follow the ISIC_<digits> naming has no image id.
        continue
    image_records.append({
        'image': 'ISIC_' + id_match.group(1),
        'globules': 2,
        'milia_like_cyst': 2,
        'negative_network': 2,
        'pigment_network': 2,
        'streaks': 2,
        'pigment_network_pred': 2,
    })

# Create the frame once from all records (growing it with pd.concat per file is
# quadratic) and rebuild it from scratch so re-running the cell cannot
# duplicate rows.
df_ISIC2018 = pd.DataFrame(
    image_records,
    columns=['image', 'globules', 'milia_like_cyst', 'negative_network',
             'pigment_network', 'streaks', 'pigment_network_pred'],
)

In [None]:
# Display the table after the image-listing loop for a visual check.
df_ISIC2018

## 2. Save binary attribute labels

In [None]:
# NOTE: the code below is disabled by wrapping it in a triple-quoted string, so
# this cell only evaluates a string literal and changes nothing.
# The wrapped code opens every .png in `mask_dir`, resizes it, derives a 0/1
# label from whether the resized mask has any non-zero pixel, and writes that
# label into the matching image row / attribute column of `df_ISIC2018`.
'''
for mask_file in os.listdir( mask_dir ):
    if not mask_file.endswith('.png'): continue
    mask = Image.open( os.path.join( mask_dir, mask_file ) )
    mask = mask.resize( (img_size, img_size) )
    mask = np.array( mask )
    
    assert mask.shape == (img_size, img_size)
    assert mask.max() <= 255
    
    # decide feature exist or not
    if mask.max() > 0:
        binary_label = 1
    else:
        binary_label = 0
    
    m = re.search(MASK_FILE_REGEX, mask_file)
    
    file_idx = df_ISIC2018[df_ISIC2018['image'].str.contains(m.group(1))].index[0]
    
    if m.group(2) ==  'globules':
        df_ISIC2018.iloc[file_idx, df_ISIC2018.columns.get_loc('globules')] = binary_label
    elif m.group(2) ==  'milia_like_cyst':
        df_ISIC2018.iloc[file_idx, df_ISIC2018.columns.get_loc('milia_like_cyst')]  = binary_label
    elif m.group(2) ==  'negative_network':
        df_ISIC2018.iloc[file_idx, df_ISIC2018.columns.get_loc('negative_network')]  = binary_label
    elif m.group(2) ==  'pigment_network':
        df_ISIC2018.iloc[file_idx, df_ISIC2018.columns.get_loc('pigment_network')]  = binary_label
    elif m.group(2) ==  'streaks':
        df_ISIC2018.iloc[file_idx, df_ISIC2018.columns.get_loc('streaks')]  = binary_label
'''

In [None]:
# Replace the table built above with the one stored in the CSV at `csv_path`.
# The commented-out line is the export that would write the table to a CSV.
#df_ISIC2018.to_csv('df_ISIC2018_val.csv', index=False)
df_ISIC2018 = pd.read_csv(csv_path)

In [None]:
#df = df_ISIC2018.loc[(df_ISIC2018.pigment_network == 0) & (df_ISIC2018.globules == 0)
#                    & (df_ISIC2018.milia_like_cyst == 0) & (df_ISIC2018.negative_network == 1) & (df_ISIC2018.streaks == 0)]

# Write train.h5

In [None]:
# Write the HDF5 file at `output_path`. Datasets created:
#   images    (count, 3, img_size, img_size) uint8 - resized images, channel-first
#   image_ids (count,)                       int   - numeric ISIC id of each row
#   masks     (count, n_attrs, img_size, img_size) uint8 - one resized mask per attribute
#   labels    (count, n_attrs)               uint8 - 1 where the resized mask has a non-zero pixel
# Rows follow the order of `df_ISIC2018`; attribute channels follow ATTR_TO_INDEX.
with h5py.File(output_path, "w") as f:
    image_names = df_ISIC2018['image'].tolist()
    count = len(image_names)
    n_attrs = len(ATTR_TO_INDEX)

    # store imgs to h5 file
    images = f.create_dataset('images', (count, 3, img_size, img_size), dtype=np.uint8)
    image_ids = f.create_dataset('image_ids', (count,), dtype=int)
    img_id_to_h5idx = {}
    for i, img_name in enumerate(image_names):
        print( f'Processing image {i+1}/{count}' )
        img_file = img_name + '.jpg'
        id_match = re.search(IMG_FILE_REGEX, img_file)
        if id_match is None:
            raise ValueError(f'Image name {img_name!r} does not match {IMG_FILE_REGEX!r}')
        img_id = int(id_match.group(1))

        with Image.open(os.path.join(img_dir, img_file)) as pil_img:
            # Force three channels so a grayscale/RGBA file cannot break the
            # channel-first transpose below.
            img = np.array(pil_img.convert('RGB').resize((img_size, img_size)))
        img = img.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        assert img.shape == (3, img_size, img_size)
        assert img.max() <= 255

        images[i] = img
        image_ids[i] = img_id
        img_id_to_h5idx[img_id] = i

    # store masks to h5 file
    masks = f.create_dataset('masks', (count, n_attrs, img_size, img_size), dtype=np.uint8)
    attr_labels = f.create_dataset('labels', (count, n_attrs), dtype=np.uint8)
    mask_files = [name for name in os.listdir(mask_dir) if name.endswith('.png')]
    skipped = 0
    for i, mask_file in enumerate(mask_files):
        print( f'Processing mask {i+1}/{len(mask_files)}' )
        m = re.search(MASK_FILE_REGEX, mask_file)
        if m is None or m.group(2) not in ATTR_TO_INDEX:
            # Not a recognised ISIC_<id>_attribute_<name>.png file.
            skipped += 1
            continue
        img_id, attr_id = int(m.group(1)), ATTR_TO_INDEX[m.group(2)]

        h5idx = img_id_to_h5idx.get(img_id)
        if h5idx is None:
            # `mask_dir` can hold masks for images that are not listed in the
            # CSV; those have no row in this file.
            skipped += 1
            continue

        with Image.open(os.path.join(mask_dir, mask_file)) as pil_mask:
            # NOTE(review): resize uses Pillow's default resampling filter; if
            # the stored masks must stay strictly binary, an explicit
            # nearest-neighbour filter may be needed - confirm with consumers.
            mask = np.array(pil_mask.resize((img_size, img_size)))
        assert mask.shape == (img_size, img_size)
        assert mask.max() <= 255

        masks[h5idx, attr_id] = mask
        if mask.max() > 0:
            attr_labels[h5idx, attr_id] = 1

    if skipped:
        print( f'Skipped {skipped} mask file(s) with no matching image or attribute' )