# Data loading 

Plan:
1. Read the Excel sheet of patient IDs, days post-transplant, and rejection grades into a pandas DataFrame
2. Make folders for binary dataset / multiclass dataset
3. Iterate through dataframe patient ids:
    - Piece together image code (PA****[P/E]N****.png)
    - Take rejection grade
    - Locate the image file in the patient's folder inside the transplant ECG folder
    - Assign label based on current classification task
    - Save the (auto-cropped) ECG into its label folder (start with the multiclass dataset)
    

## Import

In [None]:
import pandas as pd
import os
import shutil
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import skimage

## Data loading

In [None]:
# Shared root of the lab's data folder.
_rui_dir = '//data//Vandenberg-Lab//Internal//Rui//'

# Reference spreadsheet; the literal file name includes a space before ".xlsx".
FileName = _rui_dir + 'NIDACT2_ECG_reference .xlsx'
# Source folder of the ECG images that are auto-cropped further down.
ECG_folder = _rui_dir + 'NIDACT2 Transplant ECGs//original'

print(FileName)
mastersheet = pd.read_excel(FileName)

In [None]:
# Display the sheet's column headers (DataFrame.keys() returns the columns).
mastersheet.keys()


In [None]:
# Keep only the columns needed to build file names and labels. The names
# below include trailing spaces, matching the spreadsheet headers. Plain []
# indexing raises a KeyError naming any missing column; DataFrame.get would
# silently return None and only fail much later in the processing loop.
IDs_Labels = mastersheet[["Patient n. ", "Days post-transplant ", "Overall grade ", "Electronic/paper"]]
print(IDs_Labels)

In [None]:
# Name of the auto-cropped multiclass dataset folder.
dir_name = "Multiclass_HTx_Rejection_Dataset_1(AutoCropped)"

# Every class folder sits directly under the dataset root.
_dataset_root = "//data//Vandenberg-Lab//Internal//Rui//NIDACT2 Transplant ECGs//" + dir_name
no_rejection_dir, mild_rejection_dir, moderate_to_severe_rejection_dir, unidentified_dir = (
    os.sep.join((_dataset_root, class_label))
    for class_label in ("no rejection", "mild rejection", "moderate-to-severe rejection", "unidentified")
)


In [None]:
# Make label dataset folders (and, implicitly, the dataset root above them).
# exist_ok=True makes this cell safe to re-run; plain os.mkdir raises
# FileExistsError once the folders exist.
for label_dir in (no_rejection_dir, mild_rejection_dir, moderate_to_severe_rejection_dir, unidentified_dir):
    os.makedirs(label_dir, exist_ok=True)

## AUTO_CROPPING_FUNCTION

In [None]:
def AutoCropping(src, dst, horizontal_grid_size=27, grid_cut=20):
    """Crop an ECG image to the area bounded by its detected gridlines.

    The image is reduced to a contrast channel (greyscale minus green),
    thresholded at a sweep of levels, and summed along each axis. Peaks in
    those projections are taken as gridline positions. The first threshold
    whose horizontal-gridline count is within +/-2 of
    ``horizontal_grid_size`` is accepted. The image is then cropped from the
    first vertical/horizontal gridline to the last vertical gridline and the
    horizontal gridline at index ``grid_cut``, and saved to ``dst``.

    Args:
        src: Path of the source ECG image.
        dst: Path the cropped image is saved to.
        horizontal_grid_size: Expected number of major horizontal gridlines.
        grid_cut: Index of the horizontal gridline used as the bottom edge of
            the crop (13 was noted as suitable for double rhythm strips).

    Returns:
        None when the cropped image was saved; ``src`` when no usable set of
        gridlines was found (nothing is written in that case).
    """
    with Image.open(src) as original_image:
        h = original_image.size[1]

        # Convert to RGB first so RGB, RGBA (alpha dropped), greyscale and
        # palette images all yield three bands. The previous split()-based
        # unpacking only handled RGB/RGBA and used a bare except.
        _, green, _ = original_image.convert('RGB').split()
        grey = original_image.convert('L')

        # NOTE(review): both arrays are uint8, so pixels where green > grey
        # wrap around to large values. Kept as-is to preserve existing results.
        # The `< threshold` mask below zeroes wrapped values at or above the
        # threshold.
        selection = np.asarray(grey) - np.asarray(green)

        # Minimum peak spacing: 90% of the expected gridline pitch.
        # NOTE(review): also used for vertical gridlines although it is derived
        # from the image height -- confirm this is intended.
        min_distance = (h / horizontal_grid_size) * 0.9

        # TODO: incorporate skimage.measure.block_reduce to bin the image for
        # cleaner gridline detection.
        for threshold in np.arange(150, 250, 10):
            # Keep only pixels below the threshold; everything else becomes 0.
            thresholded = np.where(selection < threshold, selection, 0.0)
            vert_sum = np.sum(thresholded, axis=0)
            # Horizontal gridlines are detected from the left-most 200 columns only.
            horz_sum = np.sum(thresholded[:, :200], axis=1)
            # A peak must stand out by at least 20% of its projection's maximum.
            vert_peaks, _ = find_peaks(vert_sum, prominence=np.max(vert_sum) * 0.2, distance=min_distance)
            horz_peaks, _ = find_peaks(horz_sum, prominence=np.max(horz_sum) * 0.2, distance=min_distance)

            print(len(horz_peaks))

            if (horizontal_grid_size - 3) < len(horz_peaks) < (horizontal_grid_size + 3):
                break
        else:
            print("Did not process sample: " + src)
            return src

        # Too few peaks to form a crop box: previously an IndexError or a
        # zero-width crop.
        if len(vert_peaks) < 2 or len(horz_peaks) <= grid_cut:
            print("Did not process sample: " + src)
            return src

        crop_box = (vert_peaks[0], horz_peaks[0], vert_peaks[-1], horz_peaks[grid_cut])
        original_image.crop(crop_box).save(dst)

    return None

## Rejection labels

In [None]:
# For loop: walk the mastersheet, auto-crop each patient ECG and file it under
# its rejection-grade folder. Sources that are missing or could not be
# cropped are collected in errorfilelist.

# Ordered (substring, label, destination folder) rules; the first rule whose
# substring occurs in the grade text wins, otherwise "unidentified". The old
# separate "Grade 1A/1R, mild rejection" branch was unreachable (its text
# already contains "Grade 1A/1R, mild"), so it is covered by the mild rule.
grade_rules = [
    ("Grade 0, no rejection", "no rejection", no_rejection_dir),
    ("Grade 1A/1R, mild", "mild rejection", mild_rejection_dir),
    ("Grade 3A/2R, moderate", "moderate-to-severe rejection", moderate_to_severe_rejection_dir),
]

errorfilelist = []

for _, row in IDs_Labels.iterrows():
    patientID = row["Patient n. "]
    DPT = row["Days post-transplant "]  # Days Post Transplant
    Electronic_paper = row["Electronic/paper"]
    RejectionGrade = row["Overall grade "]

    # {:04d} zero-pads the number to 4 digits, e.g. PA0012.
    folder_name = "PA" + "{:04d}".format(patientID)
    image_name = folder_name + Electronic_paper + "N" + "{:04d}".format(DPT)

    print(folder_name + os.sep + image_name + ".png")
    print(RejectionGrade)

    # Empty grade cells arrive from pandas as NaN (a float), on which `in`
    # would raise TypeError; treat any non-string grade as unidentified.
    label, dst_dir = "unidentified", unidentified_dir
    if isinstance(RejectionGrade, str):
        for marker, rule_label, rule_dir in grade_rules:
            if marker in RejectionGrade:
                label, dst_dir = rule_label, rule_dir
                break
    print(label)

    src = ECG_folder + os.sep + folder_name + os.sep + image_name + ".png"
    dst = dst_dir + os.sep + image_name + ".png"

    # Record a missing source instead of letting one absent file abort the run.
    if not os.path.isfile(src):
        print("Missing source image: " + src)
        errorfilelist.append(src)
        continue

    # AutoCropping returns None on success and the source path on failure;
    # only failures belong in the error list.
    errorsample = AutoCropping(src, dst)
    if errorsample is not None:
        errorfilelist.append(errorsample)

print(errorfilelist)

## Image Resizing

In [None]:
# Define locations of input (auto-cropped dataset) and output (resized dataset)
input_dir = "//data//Vandenberg-Lab//Internal//Rui//NIDACT2 Transplant ECGs//Multiclass_HTx_Rejection_Dataset_1(AutoCropped)//"
output_dir = "//data//Vandenberg-Lab//Internal//Rui//NIDACT2 Transplant ECGs//Multiclass_HTx_Rejection_Dataset_1(300_300)//"

# Input size of network
net_input_size = 300

os.makedirs(output_dir, exist_ok=True)

# Each immediate sub-folder of input_dir is one label class. Mirror it in
# output_dir and write a resized, greyscale-as-RGB copy of every .png file.
# Folders are listed directly rather than matched to os.walk results by
# position, which mislabelled files whenever a class folder contained a
# nested directory. Hidden folders (e.g. .ipynb_checkpoints) are skipped.
for folder_name in sorted(os.listdir(input_dir)):
    class_in = os.path.join(input_dir, folder_name)
    if folder_name.startswith('.') or not os.path.isdir(class_in):
        continue
    print(folder_name)

    class_out = os.path.join(output_dir, folder_name)
    os.makedirs(class_out, exist_ok=True)

    for file_name in sorted(os.listdir(class_in)):
        # Only process images with a '.png' extension.
        if not file_name.endswith('.png'):
            continue
        print(file_name)

        # Close the source file as soon as the resized copy exists.
        with Image.open(os.path.join(class_in, file_name)) as originalImage:
            resizedImage = originalImage.resize((net_input_size, net_input_size))

        # Greyscale, then back to 3 channels (required for network input).
        outImage = resizedImage.convert('L').convert('RGB')
        outImage.save(os.path.join(class_out, file_name))