# In [2]:
import copy
import glob
import math
import os
import random

import numpy as np
import ot
import scipy.ndimage
from skimage import io as skio

class Utility_general(object):
    '''
    Collection of stateless helpers (all static methods): image-batch
    flattening, grayscale image loading, nearest-point search on contours,
    rasterised line segments and Wasserstein distances computed with POT
    (imported as ``ot``).
    '''

    @staticmethod
    def nz2_nearest_z(t):
        '''
        Euclidean distance transform: every non-zero element of ``t`` is
        replaced by its distance to the nearest zero element; zero elements
        map to 0. Thin wrapper around scipy.ndimage.distance_transform_edt.
        '''
        # Call the function from scipy.ndimage directly; the
        # scipy.ndimage.morphology namespace is deprecated.
        return scipy.ndimage.distance_transform_edt(t)

    @staticmethod
    def flatten_image_batch(data):
        '''
        Flatten every image of a batch into a 1D vector (row-major order).
        ===== inputs
        data: np array with shape (batch_size, height, width)
        ===== outputs
        list of batch_size arrays, each with shape (height*width,). Each entry
        is a copy, so modifying it does not modify ``data``.
        '''
        assert len(data.shape) == 3, 'data must have shape (batch_size, height, width)'
        return [image.flatten() for image in data]

    @staticmethod
    def read_image_subset(dir_in, ext_in, percentage, shuffle=True):
        '''
        Read in images in grayscale.
        NOTE, this function is copied from FILEIO to avoid call depth limit of python
        ===== inputs
        dir_in: directory searched (non-recursively) for files "*.<ext_in>"
        ext_in: file extension without the leading dot, e.g. "png"
        percentage: floor(file_num * percentage) files are read
        shuffle: True -> pick a random subset (python ``random`` module),
                 False -> take the first files in glob order
        ===== outputs
        list of the images returned by cv2.imread(..., cv2.IMREAD_GRAYSCALE)
        '''
        # Glob on a joined pattern instead of chdir-ing into dir_in: relative
        # dir_in paths stay valid and the caller's working directory is left
        # untouched. escape() keeps glob metacharacters in dir_in literal.
        pattern = os.path.join(glob.escape(dir_in), "*." + ext_in)
        address_book_in = glob.glob(pattern)
        file_num = len(address_book_in)
        read_num = int(np.floor(file_num * percentage))
        if shuffle:
            ind_list = random.sample(range(file_num), read_num)
        else:
            ind_list = range(read_num)

        # NOTE(review): cv2 (OpenCV) is not imported by this module; it must
        # already be in scope (e.g. from an earlier notebook cell) - confirm.
        res = [cv2.imread(address_book_in[i], cv2.IMREAD_GRAYSCALE) for i in ind_list]
        # Only the first image is inspected; skip the report when nothing was
        # read or the first file could not be decoded (imread gave None).
        if res and res[0] is not None:
            print("Data value range: %f - %f" % (np.amin(res[0]), np.amax(res[0])))
        return res

    @staticmethod
    def find_closest_N_points(p, point_list, N, opposite2=False, connect=False, shape=None, thickness=1):
        '''
        Find the N points of point_list closest to p (squared Euclidean
        distance). Only the first N points in distance order are visited and
        repeated coordinates among them are skipped, so fewer than N points
        are returned when the first N contain duplicates.
        ===== inputs
        p: the critical point, [x, y]
        point_list: the points to be searched through, opencv contour layout,
            point_num * [[x, y]]
        N: number of nearest points to find; ignored if opposite2 is set (but
            must still be <= point_num), integer
        opposite2: find at most two points "around" p instead, bool.
            If the closest entry equals p, return p plus the next-closest
            entry (unless that entry is p again). Otherwise return the
            closest point q plus the closest point whose offset from p has the
            signs of the non-zero components of -(q - p), if any exists.
        connect: if to connect the found points and p, bool. The rasterised
            segments (see connect_2pts) are appended after the found points;
            segments include their end points, so the result then contains
            repeated coordinates.
        shape: shape of the input images, (height, width); required if connect = True
        thickness: the thickness of the connecting line segment, integer
        ===== outputs
        x_coord, y_coord: lists of int coordinates
        ===== raises
        AssertionError if N > len(point_list); ValueError if connect is set
        without shape
        '''
        point_num = len(point_list)
        assert(N <= point_num)
        if connect and shape is None:
            raise ValueError('find_closest_N_points: shape is required when connect=True')

        # point_list follows the opencv contour layout: point i is
        # point_list[i][0][0] (x), point_list[i][0][1] (y)
        pts = [(pt[0][0], pt[0][1]) for pt in point_list]
        dist = [(p[0] - x) ** 2 + (p[1] - y) ** 2 for x, y in pts]
        increasing_sorted_index = np.argsort(dist)

        x_coord = []
        y_coord = []
        if not opposite2:
            # A set (not a comparison with the previous neighbour only) so
            # that equidistant ties cannot slip duplicates through.
            seen = set()
            for idx in increasing_sorted_index[:N]:
                key = (int(pts[idx][0]), int(pts[idx][1]))
                if key in seen:
                    continue
                seen.add(key)
                x_coord.append(key[0])
                y_coord.append(key[1])
        elif point_num > 0:
            first = increasing_sorted_index[0]
            if pts[first][0] == p[0] and pts[first][1] == p[1]:
                # p lies on the contour: report p and its next-closest entry
                x_coord.append(int(p[0]))
                y_coord.append(int(p[1]))
                if point_num > 1:
                    second = increasing_sorted_index[1]
                    if not (pts[second][0] == p[0] and pts[second][1] == p[1]):
                        x_coord.append(int(pts[second][0]))
                        y_coord.append(int(pts[second][1]))
            else:
                x_coord.append(int(pts[first][0]))
                y_coord.append(int(pts[first][1]))
                desired_x_sign = -np.sign(x_coord[0] - p[0])
                desired_y_sign = -np.sign(y_coord[0] - p[1])
                for idx in increasing_sorted_index[1:]:
                    current_x_sign = np.sign(pts[idx][0] - p[0])
                    current_y_sign = np.sign(pts[idx][1] - p[1])
                    # a zero desired component means only the other axis is checked
                    if ((desired_x_sign == 0 and current_y_sign == desired_y_sign) or
                            (desired_y_sign == 0 and current_x_sign == desired_x_sign) or
                            (desired_x_sign != 0 and desired_y_sign != 0 and
                             desired_x_sign == current_x_sign and desired_y_sign == current_y_sign)):
                        x_coord.append(int(pts[idx][0]))
                        y_coord.append(int(pts[idx][1]))
                        break

        if connect:
            x_additional = []
            y_additional = []
            for xc, yc in zip(x_coord, y_coord):
                ends = np.array([[xc, yc], [int(p[0]), int(p[1])]])
                x_, y_ = Utility_general.connect_2pts(ends, shape, thickness, split_form=True)
                x_additional.extend(x_)
                y_additional.extend(y_)
            x_coord = x_coord + x_additional
            y_coord = y_coord + y_additional
        return x_coord, y_coord

    @staticmethod
    def normalize_data_(data, reg):
        '''
        In place: replace every element x of ``data`` by (x - reg) / reg and
        return the same (mutated) ``data`` object.
        '''
        for i, value in enumerate(data):
            data[i] = (value - reg) / reg
        return data

    @staticmethod
    def wasserstein_1D_(a, b, p, wa=None, wb=None):
        '''
        compute wasserstein distance between 1d distributions
        a: source distribution, wa: weight of samples from a (None -> uniform)
        b: target distribution, wb: weight of samples from b (None -> uniform)
        p: order of wasserstein distance
        returns (dist, G): the value of ot.wasserstein_1d and the coupling
        matrix from ot.emd_1d with a Minkowski-p ground cost
        '''
        # NOTE(review): recent POT releases document ot.wasserstein_1d as the
        # OT loss (W_p ** p) rather than W_p, while wasserstein_2D_ returns
        # W_p - confirm against the installed POT version when p != 1.
        dist = ot.wasserstein_1d(a, b, wa, wb, p)
        G = ot.emd_1d(a, b, wa, wb, metric='minkowski', p=p)
        return dist, G

    @staticmethod
    def wasserstein_2D_(a, b, p, wa=None, wb=None):
        '''
        compute wasserstein distance between 2d distributions
        a: source distribution in form of [N x 2], wa: weight of samples from a in form of [N,]
        b: target distribution in form of [M x 2], wb: weight of samples from b in form of [M,]
        A weight vector left as None is replaced by uniform weights.
        p: order of wasserstein distance, only 1 or 2 are supported
        returns (dist, G): W_p (the square root is taken for p == 2) and the
        optimal coupling from ot.emd
        raises NotImplementedError for any other p
        '''
        assert(len(a.shape) == 2 and len(b.shape) == 2)
        num_a = a.shape[0]
        num_b = b.shape[0]
        # Compare with None by identity: `array != None` is element-wise and
        # cannot be used in a boolean `and`. Each weight vector is handled
        # independently so a supplied one is never discarded.
        wa = ot.unif(num_a) if wa is None else np.asarray(wa, dtype=np.float64)
        wb = ot.unif(num_b) if wb is None else np.asarray(wb, dtype=np.float64)
        assert(wa.shape[0] == num_a and wb.shape[0] == num_b)

        if p == 1.0:
            metric = "euclidean"
        elif p == 2.0:
            metric = "sqeuclidean"
        else:
            raise NotImplementedError('Unrecognized p: %f' % p)
        M = ot.dist(a, b, metric=metric)
        G = ot.emd(wa, wb, M)
        dist = np.sum(G * M)
        # the squared-euclidean ground cost yields W_2 ** 2
        return (dist if p == 1.0 else math.sqrt(dist)), G

    @staticmethod
    def wasserstein_set_distance(set1, set2, p):
        '''
        compute wasserstein distance between two 1D/2D sets: set1, set2.
        program will automatically detect the dimension of the inputs
        == input values:
        set1/set2 should have form [(X1,), (X2,), ...] in case of 1D and
        [(X1, 2), (X2, 2), ...] in case of 2D. outer loop is a list with
        each element a numpy array with shape (X,) in 1D and (X, 2) in 2D.
        == return values:
        dist: M X N double matrix
        G: M X N list with each element a matching matrix
        If either set is empty, a zero-sized dist and an empty G are returned.
        '''
        M = len(set1)
        N = len(set2)
        if M == 0 or N == 0:
            # nothing to compare, and the dimension cannot be detected
            return np.zeros((M, N)), [[None] * N for _ in range(M)]

        if len(set1[0].shape) == 1 and len(set2[0].shape) == 1:
            print("Computing 1D %d-wasserstein distance." %int(p))
            compute = Utility_general.wasserstein_1D_
        elif len(set1[0].shape) == 2 and len(set2[0].shape) == 2:
            print("Computing 2D %d-wasserstein distance." %int(p))
            compute = Utility_general.wasserstein_2D_
        else:
            raise NotImplementedError('wasserstein_set_distance: invalid inputs')

        dist = np.zeros((M, N))
        G = []
        for i in range(M):
            G_row = [None] * N
            for j in range(N):
                dist[i, j], G_row[j] = compute(set1[i], set2[j], p)
            G.append(G_row)
        return dist, G

    @staticmethod
    def _rasterize_segment(start_major, d_major, start_minor, d_minor):
        '''
        One point per unit step along the major axis of a segment
        (|d_major| >= |d_minor|, d_major != 0). Returns (major, minor) int32
        arrays of length |d_major| + 1 running from (start_major, start_minor)
        to (start_major + d_major, start_minor + d_minor) inclusive.
        '''
        n = abs(int(d_major))
        k = np.arange(n + 1)
        major = start_major + np.sign(d_major) * k
        # start_minor + k * d_minor / n rounded to the nearest integer (ties
        # rounded up), computed with integer arithmetic only
        minor = (start_minor * n + n // 2 + k * d_minor) // n
        return major.astype(np.int32), minor.astype(np.int32)

    @staticmethod
    def _thicken_line(res, thickness, axis):
        '''
        Stack ``res`` with copies shifted along column ``axis`` by +1, -1, +2,
        -2, ... so that ``thickness`` parallel lines are returned (each
        offset appears exactly once).
        '''
        layers = [res]
        for i in range(thickness - 1):
            step = i // 2 + 1
            shifted = res.copy()
            shifted[:, axis] += step if i % 2 == 0 else -step
            layers.append(shifted)
        return np.vstack(layers)

    @staticmethod
    def connect_2pts(ends, shape, thickness=1, split_form=False):
        '''
        find all points between two points. there is NO duplicate points and two
        end points are included in the outputs (even when they lie outside the image).
        ===== inputs
        ends: two end points [[x0, y0], [x1, y1]], (2, 2) integer array
        shape: the height and width of the underlying image, (2,); other points
            outside 0 <= x < width, 0 <= y < height are dropped
        thickness: thickness of the line segment, integer; extra parallel lines
            are offset along y for x-major segments and along x otherwise
        split_form: whether to output x, y coordinate together or in two lists, bool
        ===== outputs
        none split form: fil_res, (point_num, 2)
        split form: fil_x/fil_y, list of integers
        '''
        ends = np.asarray(ends)
        d0, d1 = ends[1] - ends[0]
        if d0 == 0 and d1 == 0:
            # degenerate segment: no interior points; the single end point is
            # added by the end-point handling below (thickness not applied)
            res = np.empty((0, 2), dtype=np.int32)
        elif abs(d0) > abs(d1):
            # x-major: one point per column
            xs, ys = Utility_general._rasterize_segment(ends[0, 0], d0, ends[0, 1], d1)
            res = Utility_general._thicken_line(np.c_[xs, ys], thickness, axis=1)
        else:
            # y-major (or exact diagonal): one point per row
            ys, xs = Utility_general._rasterize_segment(ends[0, 1], d1, ends[0, 0], d0)
            res = Utility_general._thicken_line(np.c_[xs, ys], thickness, axis=0)

        # ===== filter the results =====
        h = shape[0]
        w = shape[1]
        inside = (res[:, 0] >= 0) & (res[:, 0] < w) & (res[:, 1] >= 0) & (res[:, 1] < h)
        res = res[inside]

        def _contains(pt):
            return bool(np.any((res[:, 0] == pt[0]) & (res[:, 1] == pt[1])))

        # end points are always reported, even if filtered out above
        missing = []
        if not _contains(ends[0]):
            missing.append(ends[0])
        same_ends = ends[0, 0] == ends[1, 0] and ends[0, 1] == ends[1, 1]
        if not same_ends and not _contains(ends[1]):
            missing.append(ends[1])

        if split_form:
            fil_x = res[:, 0].tolist() + [pt[0].item() for pt in missing]
            fil_y = res[:, 1].tolist() + [pt[1].item() for pt in missing]
            return fil_x, fil_y
        if missing:
            res = np.vstack([res] + [pt.reshape(1, 2) for pt in missing])
        return res