# Workflow of MSI processing



In [1]:
import h5py
import re
import cv2
import os
import imageio
import json
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyimzml.ImzMLParser import ImzMLParser

from csbdeep.models import CARE
from PIL import Image
from ISR.models import RRDN

from miplib.data.containers.image import Image as resImage
import miplib.analysis.resolution.fourier_ring_correlation as frc
from miplib.data.containers.fourier_correlation_data import FourierCorrelationDataCollection
import miplib.ui.plots.frc as frcplots
from miplib.ui.cli import miplib_entry_point_options as options


## Class for interior format to store and preprocess raw data

This object also contains basic operations: save2h5, preprocess, visual

In [8]:
class Data(object):
    """
    A structure to store the data to be preprocessed
    """
    def __init__(self, data, x, y):
        self.mzdata = data
        self.xcoord = np.array(x)
        self.ycoord = np.array(y)
        self.dim_xy = np.array([max(self.xcoord), max(self.ycoord)])
        self.common_mz = self.get_commonMZ()
        self.lenmz = len(self.common_mz)
        self.counts_mat = self.get_countsMatrix()
        self.backgrd = self._get_background()
        
    def visual(self, mz = 256):
        num = len(self.xcoord)
        image = np.zeros((self.dim_xy[1], self.dim_xy[0]))

        for i in range(num):
            
            position = np.where(np.ceil(self.mzdata[i]['mz']) == np.ceil(mz))[0]
            if len(position) > 0:
                image[self.ycoord[i]-1][self.xcoord[i]-1] = sum(self.mzdata[i]['count'][position])
        fig = plt.figure(figsize = (5,5))
        p = plt.imshow(image)
        plt.show()
        #cbar = fig.colorbar(p, orientation='vertical', ticks=[0, np.max(image)], shrink = 0.8)

    def get_commonMZ(self):
        common_mz = np.array([])
        for idx in self.mzdata:
            common_mz = np.concatenate((common_mz, np.around(self.mzdata[idx]["mz"], decimals=1)))
        common_mz = np.unique(common_mz)   
        return common_mz 

    def _get_background(self):
        # Use the lowest mean value along mz axis as the background noise (which ion image shows minimum intensities)
        mean_spectrum = np.mean(self.counts_mat.reshape(self.dim_xy[1]*self.dim_xy[0], self.lenmz), axis=0) #len(self.common_mz)), axis=0)
        low_mz = np.argmin(mean_spectrum)
        bg = np.mean(self.counts_mat[:, :, low_mz])
        return bg

    def _get_snr(self, image):         ########################### from Andy
        mean_value = np.mean(image)
        snr = 20*np.log10(mean_value / self.backgrd)
        return snr

    def _cal_snr(self, noise_img, clean_img): #################### from blog
        # How to define noise_img?
        noise_signal = noise_img - clean_img
        clean_signal = clean_img

        sum1 = np.sum(clean_signal**2)
        sum2 = np.sum(noise_signal**2)
        snr = 20*np.log10(np.sqrt(sum1) / np.sqrt(sum2))
        return snr

    def filter_lowsnr(self, criteria = 35): 
        new_idx = []
        for idx in range(self.counts_mat.shape[2]):
            tmp_image = self.counts_mat[:, :, idx]
            if self._get_snr(tmp_image) > criteria:
                new_idx.append(idx)
        new_counts_mat = np.array(self.counts_mat[:, :, new_idx])
        new_cmz = np.array(self.common_mz[new_idx])

        self.counts_mat = new_counts_mat
        self.common_mz = new_cmz

        return new_cmz, new_counts_mat
                
    def get_countsMatrix(self):
        mz_len = self.lenmz #len(self.common_mz)
        counts_mat = np.zeros((self.dim_xy[1], self.dim_xy[0], mz_len))
        for idx in range(mz_len):
            for i in range(len(self.xcoord)):
                position = np.where(np.around(self.mzdata[i]["mz"], decimals=1) == self.common_mz[idx])[0]
                if len(position) > 0 and len(self.mzdata[i]["count"]) == len(self.mzdata[i]["mz"]):
                    counts_mat[self.ycoord[i]-1, self.xcoord[i]-1, idx] = sum(self.mzdata[i]["count"][position])
        return counts_mat
    
    def save2h5(self, outPath):
        f = h5py.File(outPath, 'w')
        
        for idx in self.mzdata:
            f.create_dataset('mz_' + str(idx), data = self.mzdata[idx]["mz"].astype(np.float))
            f.create_dataset('counts_' + str(idx), data = self.mzdata[idx]["count"].astype(np.float))

        f.create_dataset('x_coord', data = np.array(self.xcoord, dtype = np.float))
        f.create_dataset('y_coord', data = np.array(self.ycoord, dtype = np.float))
        f.close()

    def get_mask(self):
        img = np.sum(self.counts_mat, axis=2)
        img = np.array((img / np.max(img))*255)
        # tmp_img = img
        img = np.array(Image.fromarray(img).convert('L'))
        # img = cv2.GaussianBlur(img, (5, 5), 0)
        kernel = np.ones((5,5))

        iterations = 10
        for i in range(iterations):
            edge = cv2.Canny(img, 10, 100)
            edge = np.array(edge)   
            dilate = cv2.dilate(edge, kernel, iterations = 1)
            erosion = cv2.erode(dilate, kernel, iterations = 1)

            img1, contours, hierarchy=cv2.findContours(erosion,cv2.RETR_CCOMP,cv2.CHAIN_APPROX_NONE)
            image = cv2.drawContours(img, contours, 0, (255, 255, 255), -1)
            break

        res_img = np.zeros((img.shape[0], img.shape[1]))
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                if img[i][j] == 255:
                    res_img[i][j] = 255 #tmp_img[i][j]

        return res_img
    

In [3]:
class TempData(object):
    """
    A structure to store the data to be preprocessed
    """
    def __init__(self, dim_xy, cz, counts_mat):
        
        self.dim_xy = dim_xy
        self.common_mz = cz
        self.lenmz = len(self.common_mz)
        self.counts_mat = counts_mat
        self.backgrd = self._get_background()
        
    def visual(self, mz = 256):
        num = len(self.xcoord)
        image = np.zeros((self.dim_xy[1], self.dim_xy[0]))

        for i in range(num):
            
            position = np.where(np.ceil(self.mzdata[i]['mz']) == np.ceil(mz))[0]
            if len(position) > 0:
                image[self.ycoord[i]-1][self.xcoord[i]-1] = sum(self.mzdata[i]['count'][position])
        fig = plt.figure(figsize = (5,5))
        p = plt.imshow(image)
        plt.show()
        #cbar = fig.colorbar(p, orientation='vertical', ticks=[0, np.max(image)], shrink = 0.8)

    def get_commonMZ(self):
        common_mz = np.array([])
        for idx in self.mzdata:
            common_mz = np.concatenate((common_mz, np.around(self.mzdata[idx]["mz"], decimals=1)))
        common_mz = np.unique(common_mz)   # 8002
        return common_mz #[:151]

    def _get_background(self):
        # Use the lowest mean value along mz axis as the background noise (which ion image shows minimum intensities)
        mean_spectrum = np.mean(self.counts_mat.reshape(self.dim_xy[1]*self.dim_xy[0], self.lenmz), axis=0) #len(self.common_mz)), axis=0)
        low_mz = np.argmin(mean_spectrum)
        bg = np.mean(self.counts_mat[:, :, low_mz])
        return bg

    def _get_snr(self, image):         ########################### from Andy
        mean_value = np.mean(image)
        snr = 20*np.log10(mean_value / self.backgrd)
        return snr

    def _cal_snr(self, noise_img, clean_img): #################### from blog
        # How to define noise_img?
        noise_signal = noise_img - clean_img
        clean_signal = clean_img

        sum1 = np.sum(clean_signal**2)
        sum2 = np.sum(noise_signal**2)
        snr = 20*np.log10(np.sqrt(sum1) / np.sqrt(sum2))
        return snr

    def filter_lowsnr(self, criteria = 35): # or we can combine this inside model processing (save time)
        new_idx = []
        for idx in range(self.counts_mat.shape[2]):
            tmp_image = self.counts_mat[:, :, idx]
            if self._get_snr(tmp_image) > criteria:
                new_idx.append(idx)
        new_counts_mat = np.array(self.counts_mat[:, :, new_idx])
        new_cmz = np.array(self.common_mz[new_idx])

        self.counts_mat = new_counts_mat
        self.common_mz = new_cmz

        return new_cmz, new_counts_mat
                
    def get_countsMatrix(self):
        mz_len = self.lenmz #len(self.common_mz)
        counts_mat = np.zeros((self.dim_xy[1], self.dim_xy[0], mz_len))
        for idx in range(mz_len):
            for i in range(len(self.xcoord)):
                position = np.where(np.around(self.mzdata[i]["mz"], decimals=1) == self.common_mz[idx])[0]
                if len(position) > 0 and len(self.mzdata[i]["count"]) == len(self.mzdata[i]["mz"]):
                    counts_mat[self.ycoord[i]-1, self.xcoord[i]-1, idx] = sum(self.mzdata[i]["count"][position])
        return counts_mat
    
    def save2h5(self, outPath):
        f = h5py.File(outPath, 'w')
        
        for idx in self.mzdata:
            f.create_dataset('mz_' + str(idx), data = self.mzdata[idx]["mz"].astype(np.float))
            f.create_dataset('counts_' + str(idx), data = self.mzdata[idx]["count"].astype(np.float))

        f.create_dataset('x_coord', data = np.array(self.xcoord, dtype = np.float))
        f.create_dataset('y_coord', data = np.array(self.ycoord, dtype = np.float))
        f.close()

    def get_mask(self):
        img = np.sum(self.counts_mat, axis=2)
        img = np.array((img / np.max(img))*255)
        # tmp_img = img
        img = np.array(Image.fromarray(img).convert('L'))
        # img = cv2.GaussianBlur(img, (5, 5), 0)
        kernel = np.ones((5,5))

        #iterations = 1
        #for i in range(iterations):
        edge = cv2.Canny(img, 20, 50)
        edge = np.array(edge)   
        dilate = cv2.dilate(edge, kernel, iterations = 1)
        erosion = cv2.erode(dilate, kernel, iterations = 1)

        contours, hierarchy=cv2.findContours(erosion,cv2.RETR_CCOMP,cv2.CHAIN_APPROX_NONE)
        image = cv2.drawContours(img, contours, 0, (255, 255, 255), -1)

        res_img = np.zeros((img.shape[0], img.shape[1]))
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                if img[i][j] == 255:
                    res_img[i][j] = 255 #tmp_img[i][j]

        return res_img
    

## Class for model processing

This object contains two trained DL models for MS imaging enhancement.

In [4]:
class Process(object):
    """
    The processing of the input image data through our trained models (GAN + UNET5)

    Parameters
    ----------
    gan_weights : str, optional
        Path of the RRDN generator weights file. Defaults to the original
        hard-coded location.
    unet5_weights : str, optional
        Base directory holding the CARE model ``400s_40ep_0.5n``. Defaults
        to the original hard-coded location.

    Both models are loaded during construction.
    """
    def __init__(self, gan_weights = None, unet5_weights = None):
        if gan_weights is None:
            gan_weights = "D:/BMR-DS/Project_2/Code/weights/rrdn-C4-D3-G32-G032-T10-x4_best-val_generator_loss_epoch051.hdf5"
        if unet5_weights is None:
            unet5_weights = "D:/BMR-DS/Project_2/Code/weights"
        self.gan_weights = gan_weights
        self.unet5_weights = unet5_weights

        self.unet5Model = self.unet5_load_weights()
        self.gan_load_weights()

    def unet5_load_weights(self):
        """Load and return the CARE model stored under ``self.unet5_weights``."""
        unet5Model = CARE(config = None, name = '400s_40ep_0.5n', basedir = self.unet5_weights)
        print("Done.")
        return unet5Model

    def gan_load_weights(self):
        """Build the RRDN network and load ``self.gan_weights`` into ``self.ganModel``."""
        self.ganModel = RRDN(arch_params={'C': 4, 'D':3, 'G':32, 'G0':32, 'T': 10, 'x':4})
        print("Loading GAN weights...")
        self.ganModel.model.load_weights(self.gan_weights)
        print("Done.")

    @staticmethod
    def _wants_png(saveFlag, format):
        """Return True when saving is requested and ``format`` names PNG.

        Both ``'png'`` and ``'.png'`` are accepted, so the differing
        defaults of the public methods all select PNG output.
        """
        return bool(saveFlag) and str(format).lstrip('.').lower() == 'png'

    def gan_enhance(self, image, saveFlag = False, outPath = '', format = 'png'):
        """Run the GAN on a 3-D (H, W, C) image and return its prediction.

        When ``saveFlag`` is true and ``format`` is PNG the result is also
        written to ``outPath``.
        """
        gan_image = self.ganModel.predict(image)

        if self._wants_png(saveFlag, format):
            self.save2png(gan_image, outPath)

        return gan_image

    def unet5_denoise(self, image, saveFlag = False, outPath = '', format = '.png'):
        """Run the UNET5 model on a 2-D image (axes 'YX') and return its prediction.

        When ``saveFlag`` is true and ``format`` is PNG the result is also
        written to ``outPath``.
        """
        unet5_image = self.unet5Model.predict(image, axes = 'YX')

        if self._wants_png(saveFlag, format):
            self.save2png(unet5_image, outPath)

        return unet5_image

    def gan_unet5(self, image, saveFlag = False, outPath = '', format = '.png'):
        """Run GAN then UNET5 (on channel 0 of the GAN output); return the UNET5 result.

        When ``saveFlag`` is true and ``format`` is PNG the result is also
        written to ``outPath``.
        """
        print("Processing single image by ESRGAN and UNET5...")
        gan_image = self.ganModel.predict(image)
        unet5_image = self.unet5Model.predict(gan_image[:, :, 0], axes = 'YX')

        if self._wants_png(saveFlag, format):
            self.save2png(unet5_image, outPath)

        return unet5_image

    def save2png(self, image, outPath):
        """Save ``image`` to ``outPath`` via PIL.

        Floating-point arrays are clipped to 0..255 and stored as 8-bit,
        because PIL cannot write floating-point images as PNG. Other
        dtypes are passed to PIL unchanged.
        """
        pixels = np.asarray(image)
        if np.issubdtype(pixels.dtype, np.floating):
            pixels = np.clip(np.nan_to_num(pixels), 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save(outPath)

    def plot_original_image(self, image):
        """Plot channel 0 of the (H, W, C) input image."""
        plt.figure(figsize=(10, 8))
        plt.imshow(image[:,:,0])
        plt.title("Original Image",fontsize = 30)

    def plot_gan_image(self, image):
        """Plot channel 0 of the (H, W, C) GAN output."""
        plt.figure(figsize=(10, 8))
        plt.imshow(image[:, :, 0])
        plt.title("ESRGAN Image",fontsize = 30)

    def plot_unet5_image(self, image):
        """Plot the 2-D UNET5 output."""
        plt.figure(figsize=(10,8))
        plt.imshow(image[:,:])
        plt.title("UNET5 Image",fontsize = 30)

    def plot_all(self, image, gan_image, unet5_image):
        """Plot original, GAN and GAN+UNET5 images side by side."""
        plt.figure(figsize=(10, 8))
        panels = (
            (131, image[:, :, 0], "Original Image"),
            (132, gan_image[:, :, 0], "ESRGAN Image"),
            (133, unet5_image[:,:], "ESRGAN+UNET5 Image"),
        )
        for position, picture, title in panels:
            plt.subplot(position)
            plt.imshow(picture)
            plt.title(title)

## Class FRC to calculate resolution using FRC

In [5]:
class FRC(object):
    """
    Resolution estimation by Fourier ring correlation (miplib).

    Parameters
    ----------
    spacing : sequence of float, optional
        Physical size of each pixel (um). Defaults to ``[100.0, 100.0]``;
        a fresh list is created per instance so instances never share it.
    """
    def __init__(self, spacing = None):
        self.spacing = [100.0, 100.0] if spacing is None else spacing
        self.args = self.setup()

    def setup(self):
        """Build the miplib FRC option object used by every calculation."""
        args_list = ("None --bin-delta=1  --frc-curve-fit-type=smooth-spline "
             " --resolution-threshold-criterion=fixed").split() # half-bit

        args = options.get_frc_script_options(args_list)

        return args

    def _resolve_spacing(self, spacing):
        """Return the per-call ``spacing`` as a list, or the instance default."""
        if spacing is None:
            return self.spacing
        return list(spacing)

    def image2Image(self, image, spacing = None):
        """Wrap a plain array in a miplib image carrying the pixel spacing.

        ``spacing`` overrides the instance spacing for this call only.
        """
        image = resImage(image, self._resolve_spacing(spacing))
        return image

    def get_resolution(self, image, spacing = None):
        """Return the single-image FRC resolution of ``image``.

        Parameters
        ----------
        image : array
            Plain image array; it is wrapped with ``image2Image`` here.
        spacing : sequence of float, optional
            Pixel size for this call; the instance spacing is used when
            omitted.
        """
        image = self.image2Image(image, spacing)
        frc_results = FourierCorrelationDataCollection()

        frc_results[0] = frc.calculate_single_image_frc(image, self.args)
        resolution = frc_results[0].resolution["resolution"]

        return resolution

    def plot_curve_without_fitting(self, image):
        """Print and plot the raw correlation-vs-frequency curve (no resolution)."""
        # Cannot get the resolution
        image = self.image2Image(image)
        frc_results = FourierCorrelationDataCollection()

        # NOTE(review): unlike get_resolution, the result replaces the
        # collection instead of being stored at index 0, yet it is indexed
        # with [0] below - confirm what calculate_single_image_frc_without_fit
        # returns.
        frc_results = frc.calculate_single_image_frc_without_fit(image, self.args)
        X = frc_results[0].correlation["frequency"]
        Y = frc_results[0].correlation["correlation"]
        print("Frequency: \n {}".format(X))
        print("Correlation: \n {}".format(Y))

        # Plot the result without fitting
        plt.figure(figsize=(5,5))
        plt.plot(X, Y)

## Some functions to read or convert different format data

In [6]:
def imzML_to_h5(inPath, outPath):
    """Convert the imzML file ``inPath`` to the HDF5 file ``outPath``.

    For spectrum ``i`` the datasets ``mz_<i>`` and ``counts_<i>`` are
    written; the pixel coordinates go to ``x_coord`` and ``y_coord``.
    Everything is stored as float64.
    """
    p = ImzMLParser(inPath)

    xcoord = []
    ycoord = []
    # ``np.float`` no longer exists in current NumPy; float64 is the type
    # it used to alias. The ``with`` block closes the file on error.
    with h5py.File(outPath, 'w') as f:
        for idx, (x, y, z) in enumerate(p.coordinates):
            mzs, intensities = p.getspectrum(idx)
            f.create_dataset('mz_' + str(idx), data = np.asarray(mzs, dtype = np.float64))
            f.create_dataset('counts_' + str(idx), data = np.asarray(intensities, dtype = np.float64))
            xcoord.append(x)
            ycoord.append(y)
        f.create_dataset('x_coord', data = np.array(xcoord, dtype = np.float64))
        f.create_dataset('y_coord', data = np.array(ycoord, dtype = np.float64))

def imzML_to_predata(inPath):
    """Read the imzML file ``inPath`` and return it as a ``Data`` object.

    Each spectrum ``i`` is stored as ``{"count": ..., "mz": ...}`` (float64
    arrays) under key ``i``; the x and y pixel coordinates are collected in
    the same order.
    """
    p = ImzMLParser(inPath)
    data = dict()

    xcoord = []
    ycoord = []
    for idx, (x, y, z) in enumerate(p.coordinates):
        mzs, intensities = p.getspectrum(idx)
        # ``np.float`` no longer exists in current NumPy; float64 is the
        # type it used to alias.
        data[idx] = {
            "count": np.array(intensities, dtype = np.float64),
            "mz": np.array(mzs, dtype = np.float64),
        }
        xcoord.append(x)
        ycoord.append(y)
    PreData = Data(data, xcoord, ycoord)
    return PreData

def h5_to_predata(inPath):
    """Read an HDF5 file with ``mz_<i>`` / ``counts_<i>`` datasets into a ``Data`` object.

    ``x_coord`` and ``y_coord`` are cast to int64 because they are used as
    array indices later on.
    """
    prefix = "counts_"
    data = dict()
    # Copy everything into memory inside the ``with`` block so the file is
    # closed before the expensive ``Data`` construction and on any error.
    with h5py.File(inPath, 'r') as f:
        x = np.array(f['x_coord']).astype(np.int64)
        y = np.array(f['y_coord']).astype(np.int64)

        for key in f.keys():
            if key.startswith(prefix):
                idx = int(key[len(prefix):])
                data[idx] = {
                    'count': np.array(f[key]),
                    'mz': np.array(f['mz_' + str(idx)]),
                }
    PreData = Data(data, x, y)
    return PreData

def countsmatrix_to_h5(counts_mat, common_mz, dim_xy, outPath):
    """Write an intensity cube to the HDF5 file ``outPath``.

    Datasets: ``dim_xy`` (integer), ``cmz`` (float64 m/z axis) and
    ``counts_mat`` (float64 cube).
    """
    # ``np.int`` / ``np.float`` no longer exist in current NumPy; the
    # builtin ``int`` and float64 are the types they used to alias.
    with h5py.File(outPath, 'w') as f:
        f.create_dataset("dim_xy", data = np.array(dim_xy, dtype = int))
        f.create_dataset("cmz", data = np.array(common_mz, dtype = np.float64))
        f.create_dataset("counts_mat", data = np.array(counts_mat, dtype = np.float64))

def selectROI(img, mask):
    """Return a float copy of 2-D ``img`` that is zero wherever ``mask`` is not 255.

    Only the part of ``mask`` covering ``img`` is consulted.
    """
    rows, cols = img.shape[0], img.shape[1]
    selected = np.asarray(mask)[:rows, :cols] == 255
    roi = np.zeros((rows, cols))
    roi[selected] = img[selected]
    return roi

## Main processing cell

In [None]:
if __name__ == "__main__":
    # Batch run: read each raw file, drop low-SNR channels, mask the
    # tissue, enhance every ion image (ESRGAN then UNET5), estimate the FRC
    # resolution of each stage, and save the enhanced cube plus a CSV.
    start = time.time()

    # The path can be changed to your own.
    fdPath = "D:/BMR-DS/Project_2/DataSet/Test Set/Test Set H5"
    outfdPath = "D:/BMR-DS/Project_2/DataSet/Test Set/Test Set Images"
    resfdPath = "D:/BMR-DS/Project_2/DataSet/Test Set/Test Set Resolution"
    fList = os.listdir(fdPath)

    Model = Process()
    resFRC = FRC()
    for fName in fList:
        spacing = [100.0, 100.0]

        # It works for the test set naming format, you can change for fitting your aim.
        # Aim is to obtain the pixel size of the raw data.
        if fName != "A30.h5":
            try:
                infos = fName.split("_")
                psize = infos[6].split("-")
                spacing = [float(psize[0]), float(psize[1])]
            except (IndexError, ValueError):
                print("Cannot read the pixel size from '{}'; using {}.".format(fName, spacing))
        # The FRC calculator reads its pixel size from this attribute, so
        # the per-file value has to be stored there to take effect.
        resFRC.spacing = spacing

        fPath = os.path.join(fdPath, fName)
        # Derive both outputs from the stem so an .h5 input does not end
        # up with a result table that is also named .h5.
        stem = os.path.splitext(fName)[0]
        outPath = os.path.join(outfdPath, stem + ".h5")
        resPath = os.path.join(resfdPath, stem + ".csv")

        if fName.lower().endswith(".h5"):
            Predata = h5_to_predata(fPath)
        else:
            Predata = imzML_to_predata(fPath)
        print("input data ready!")
        common_mz, counts_mat = Predata.filter_lowsnr()

        mask = Predata.get_mask()
        print("filter and mask done!")

        # One output channel per channel that survived the SNR filter.
        n_images = len(common_mz)
        enhanced_counts_mat = np.zeros((Predata.dim_xy[1], Predata.dim_xy[0], n_images))
        resized = False
        roi = (mask == 255)
        result = pd.DataFrame()
        original = []
        gan = []
        gan_unet5 = []
        index = []

        for idx in range(n_images):
            temp = np.where(roi, counts_mat[:, :, idx], 0.0)
            maxmz = np.max(temp)
            if maxmz <= 0:
                # Nothing inside the mask: scaling would divide by zero.
                continue
            temp = np.array((temp / maxmz)*255)
            image = np.stack([temp, temp, temp], axis=2)

            gan_image = Model.gan_enhance(image)
            unet_image = Model.unet5_denoise(gan_image[:,:,0])
            unet_image = np.array(unet_image)

            if not resized:
                # The models change the image size; size the output cube
                # from the first enhanced image.
                enhanced_counts_mat = np.zeros((unet_image.shape[0], unet_image.shape[1], n_images))
                resized = True
            peak = np.max(unet_image)
            if peak > 0:
                enhanced_counts_mat[:,:,idx] = unet_image*(maxmz/peak)

            # Calculate the resolution by FRC. All three values are
            # computed before any is stored so the result columns always
            # have equal length.
            try:
                res_original = resFRC.get_resolution(image[:, :, 0])
                res_gan = resFRC.get_resolution(gan_image[:, :, 0])
                res_unet = resFRC.get_resolution(unet_image[:, :])
            except Exception as err:
                print("FRC failed for m/z {}: {}".format(common_mz[idx], err))
                continue
            original.append(res_original)
            gan.append(res_gan)
            gan_unet5.append(res_unet)
            index.append(common_mz[idx])
        print("images restoration done!")

        result['mz'] = index
        result['original'] = original
        result['ESRGAN'] = gan
        result['ESRGAN_UNET5'] = gan_unet5
        result.to_csv(resPath)
        countsmatrix_to_h5(enhanced_counts_mat, common_mz, Predata.dim_xy, outPath)

        # NOTE(review): only the first file is processed because of this
        # break - remove it to run the whole folder.
        break
    print("Time: {}.".format(time.time()-start))

## Some code for multiprocessing

If your PC or server supports multiprocessing, you can adapt the code below to reduce the execution time.

In [None]:
#import multiprocessing
def load_process(fPath, outPath, resPath):
    """Enhance every ion image of one cube file and record FRC resolutions.

    Parameters
    ----------
    fPath : str
        HDF5 file holding ``counts_mat``, ``dim_xy`` and ``cmz``.
    outPath : str
        HDF5 file that receives the enhanced cube.
    resPath : str
        CSV file that receives the per-channel resolutions.

    Relies on the module-level ``Model`` (a ``Process``) and ``resFRC``
    (an ``FRC``) objects; the pixel size parsed from the file name is
    stored on ``resFRC.spacing``.
    """
    # read data
    with h5py.File(fPath, 'r') as f:
        counts_mat = np.array(f["counts_mat"])
        dim_xy = np.array(f["dim_xy"])
        common_mz = np.array(f["cmz"])

    # set up for spacing: taken from the file name where it is encoded,
    # otherwise the default pixel size is used.
    name = os.path.basename(fPath)
    pixel_size = [100.0, 100.0]
    if name != "A30.h5":
        try:
            infos = name.split("_")
            psize = infos[6].split("-")
            pixel_size = [float(psize[0]), float(psize[1])]
        except (IndexError, ValueError):
            print("Cannot read the pixel size from '{}'; using {}.".format(name, pixel_size))
    resFRC.spacing = pixel_size

    Predata = TempData(dim_xy, common_mz, counts_mat)
    common_mz, counts_mat = Predata.filter_lowsnr()

    # variables for results
    n_images = len(common_mz)
    enhanced_counts_mat = np.zeros((Predata.dim_xy[1], Predata.dim_xy[0], n_images))
    resized = False
    result = pd.DataFrame()
    original = []
    gan = []
    gan_unet5 = []
    index = []
    print("Total ion images : {}.".format(len(common_mz)))

    for idx in range(n_images):
        tmp = counts_mat[:,:,idx]
        maxmz = np.max(tmp)
        if maxmz <= 0:
            # Empty channel: scaling would divide by zero.
            continue
        tmp = np.array((tmp / maxmz)*255)
        image = np.stack([tmp, tmp, tmp], axis=2)

        # model restoration
        gan_image = Model.gan_enhance(image)
        unet_image = Model.unet5_denoise(gan_image[:,:,0])
        unet_image = np.array(unet_image)
        unet_image[unet_image<0] = 0

        if not resized:
            # The models change the image size; size the output cube from
            # the first enhanced image.
            enhanced_counts_mat = np.zeros((unet_image.shape[0], unet_image.shape[1], n_images))
            resized = True
        # Store the enhanced image rescaled to the channel's original
        # maximum; without this the saved cube stays all zeros.
        peak = np.max(unet_image)
        if peak > 0:
            enhanced_counts_mat[:,:,idx] = unet_image*(maxmz/peak)

        # All three values are computed before any is stored so the
        # result columns always have equal length.
        try:
            res_original = resFRC.get_resolution(image[:, :, 0])
            res_gan = resFRC.get_resolution(gan_image[:, :, 0])
            res_unet = resFRC.get_resolution(unet_image[:, :])
        except Exception as err:
            print("FRC failed for m/z {}: {}".format(common_mz[idx], err))
            continue
        original.append(res_original)
        gan.append(res_gan)
        gan_unet5.append(res_unet)
        index.append(common_mz[idx])

    result['mz'] = index
    result['original'] = original
    result['ESRGAN'] = gan
    result['ESRGAN_UNET5'] = gan_unet5
    result.to_csv(resPath)
    countsmatrix_to_h5(enhanced_counts_mat, common_mz, Predata.dim_xy, outPath)
        
if __name__ == "__main__": 
    # Batch run over already assembled cube files (sequential version).
    fdPath = "D:/BMR-DS/Project_2/DataSet/Test Set/Test Set H5"
    outfdPath = "D:/BMR-DS/Project_2/DataSet/Test Set/Test Set Images"
    resfdPath = "D:/BMR-DS/Project_2/DataSet/Test Set/Test Set Resolution"
    fList = os.listdir(fdPath)
    
    s = time.time()
    spacing = [100.0, 100.0]
    #_processes = [] # multiple Process set up!
    # load_process reads these two module-level objects.
    Model = Process()
    resFRC = FRC()
    
    
    for fName in fList:
        if fName.find(".ipynb_checkpoints")>=0:
            continue
        print("{} is processing...".format(fName))
        fPath = os.path.join(fdPath, fName)
        outPath = os.path.join(outfdPath, fName)
        resPath = os.path.join(resfdPath, fName.replace(".h5", ".csv"))
        load_process(fPath, outPath, resPath)
        print("Done! Time: {}.".format(time.time()-s))
        #_process = multiprocessing.Process(target=load_process, args=(fPath, outPath, resPath))
        #_process.start()
        #_processes.append(_process)
    #for _process in _processes:
    #    _process.join()

    elapsed = time.time()-s
    print(elapsed)
    # The log is opened in append mode: make sure its folder exists and
    # end each entry with a newline so successive runs stay separate.
    os.makedirs("DataSet", exist_ok=True)
    with open("DataSet/time.txt", "a") as f:
        f.write(str(elapsed) + "\n")

    

'''
from multiprocessing import Pool

def load_process(fName):
    if 
    fPath = os.path.join(fdPath, fName)
    f = h5py.File(fPath, 'r')
    counts_mat = np.array(f["counts_mat"])
    dim_xy = np.array(f["dim_xy"])
    common_mz = np.array(f["cmz"])
    f.close()
    
if __name__ == "__main__": 
    startT = time.time()
    fdPath = "DataSet/Test Set H5"
    fList = os.listdir(fdPath)
    
    pool = Pool()
    pool.map(load_process, fList)
    pool.close()
    pool.join()
    print(time.time()-startT)
'''

## Single image processing -- from raw to resolution

Only extract one image from raw data

The MS image is generated by selecting one peak (e.g. m/z = 256) from the raw data cube.

In [10]:
def _ion_image_rgb(predata, mz):
    """Build a 3-channel 0..255 image for ``mz`` from ``predata.counts_mat``.

    All channels whose ceil(m/z) equals ceil(mz) are summed, scaled so the
    maximum is 255, and repeated on three channels.
    """
    axis = np.asarray(predata.common_mz)
    channels = np.where(np.ceil(axis) == np.ceil(mz))[0]
    plane = np.sum(predata.counts_mat[:, :, channels], axis=2)
    peak = np.max(plane) if plane.size else 0
    if peak > 0:  # leave an empty image at zero instead of dividing by zero
        plane = (plane / peak)*255
    return np.stack([plane, plane, plane], axis=2)


def process_single_image(inPath, mz = 255):
    """Enhance one ion image of a raw file and return its FRC resolutions.

    Parameters
    ----------
    inPath : str
        Path of an ``.h5`` or ``.imzML`` file (extension matched without
        regard to case).
    mz : number
        Peak to extract.

    Returns
    -------
    tuple
        ``(original, ESRGAN, ESRGAN+UNET5)`` resolutions.

    Raises
    ------
    ValueError
        If the file extension is neither ``.h5`` nor ``.imzml``.
    """
    lowered = inPath.lower()
    if lowered.endswith(".h5"):
        Predata = h5_to_predata(inPath)
    elif lowered.endswith(".imzml"):
        Predata = imzML_to_predata(inPath)
    else:
        raise ValueError("Unsupported input file (expected .h5 or .imzml): {}".format(inPath))

    # Model processing
    Model = Process()
    # Use the data object's own converter when it offers one; otherwise
    # build the image from its intensity cube.
    converter = getattr(Predata, "transfer2image", None)
    image = converter(mz) if callable(converter) else _ion_image_rgb(Predata, mz)
    gan_image = Model.gan_enhance(image)
    unet_image = Model.unet5_denoise(gan_image[:,:,0])

    # Plot all images
    Model.plot_all(image, gan_image, unet_image)

    # Calculate resolution
    resFRC = FRC()
    ori = resFRC.get_resolution(image[:, :, 0])
    enh = resFRC.get_resolution(gan_image[:, :, 0])
    den = resFRC.get_resolution(unet_image)

    return (ori, enh, den)

## Multiple images processing -- from images to resolution

The input is a folder of .png format images.

In [None]:
def process_multiple_images(fdrPath, save = False, outPath = None):
    """Estimate the FRC resolution of every image file in ``fdrPath``.

    Parameters
    ----------
    fdrPath : str
        Folder containing the images.
    save : bool
        When true, also write the results to the Excel file ``outPath``.
    outPath : str, optional
        Target file; required when ``save`` is true.

    Returns
    -------
    dict
        Maps file name to resolution. Files that cannot be read or solved
        are reported on stdout and left out.

    Raises
    ------
    ValueError
        If ``save`` is true but no ``outPath`` is given.
    """
    # Fail before doing any work rather than after the whole folder.
    if save and outPath is None:
        raise ValueError("outPath is required when save is True.")

    frc_results = dict()
    resFRC = FRC()

    for f in os.listdir(fdrPath):
        fpath = os.path.join(fdrPath, f)
        try:
            image = imageio.imread(fpath)
            frc_results[f] = resFRC.get_resolution(image)
        except Exception:
            print("{} cannot be sovled.".format(f))

    if save:
        result = pd.DataFrame({
            "file name": list(frc_results.keys()),
            "resolution": list(frc_results.values()),
        })
        result.to_excel(outPath)

    return frc_results

## Raw data processing -- from raw to resolution

The input is the imzML format raw data.

First, transfer the imzML data to numpy_image; Then, process and solve the numpy_image.

How to get "count_matrix.h5", normalise the mz vector. (different types of tissue need the same method)

In [None]:
'''
def process_raw_data(inPath):
    if re.match("(.*)\.h5$", inPath):
        Predata = h5_to_predata(inPath)
    elif re.match("(.*)\.imzml$", inPath):
        Predata = imzML_to_predata(inPath)

    Model = Process()
    dataset = Predata
    enh_dataset = dict()
    resolution_dataset = dict()
    for i in dataset:
        image = np.array(dataset[i])
        enh_image = Model.gan_unet5(image)
        enh_dataset[i] = enh_image
        
        resFRC = FRC()
        resolution_dataset[i] = resFRC.get_resolution(enh_image[:, :])

    return enh_dataset, resolution_dataset
'''
def process_raw_data(inPath):
    if re.match("(.*)\.h5$", inPath):
        Predata = h5_to_predata(inPath)
    elif re.match("(.*)\.imzml$", inPath):
        Predata = imzML_to_predata(inPath)

    Model = Process()
       