# In [None]:
import persim
import ot.plot
import numpy as np
import matplotlib.pyplot as plt
from skimage import io as skio

class Viewer(object):
    '''
    Namespace of static helpers for displaying images, persistence diagrams,
    2-D point-set matchings and boundary/region coordinates while doing
    research and debugging. Every method is a @staticmethod.
    '''

    @staticmethod
    def imshow_fp_(path):
        '''
        Read an image from a file path ("fp" = file path) and display it.

        path: file path handed to skimage.io.imread.
        The 'gray' colormap only affects single-channel images; matplotlib
        ignores cmap for RGB(A) data.
        '''
        img = skio.imread(path)
        plt.imshow(img, cmap='gray')
        plt.show()

    @staticmethod
    def imshow_(img):
        '''
        Display an in-memory image array with the 'gray' colormap.

        Singleton axes are squeezed out first, so e.g. an (H, W, 1) or
        (1, H, W) array is shown as an (H, W) image.
        '''
        plt.imshow(np.squeeze(img), cmap='gray')
        plt.show()

    @staticmethod
    def plot_dgm_wconv_(phc_pd, dim_list):
        '''
        Convert a persistence result to persim diagrams, then plot them.

        phc_pd and dim_list are forwarded unchanged to
        Utility_topo.convert_phc_pd_2_persim_batch (defined elsewhere);
        see that function for the expected formats.
        '''
        dgm_list = Utility_topo.convert_phc_pd_2_persim_batch(phc_pd, dim_list)
        persim.plot_diagrams(dgm_list)

    @staticmethod
    def plot_dgm_(dgm_list):
        '''
        Plot persistence diagrams with persim.plot_diagrams.

        dgm_list: batch_size * [structure_num * [x, y]]
        '''
        persim.plot_diagrams(dgm_list)

    @staticmethod
    def plot_2dset_match_(set1, set2, match_mat):
        '''
        Plot two 2-D point sets and the coupling between them in a new figure.

        set1, set2: point sets of shape (n1, 2) and (n2, 2); converted with
            np.asarray so nested lists are accepted as well as arrays.
        match_mat: (n1, n2) coupling matrix forwarded to
            ot.plot.plot2D_samples_mat, which draws lines between coupled
            samples.
        Source points are drawn as green '+', target points as red 'x'.
        '''
        set1 = np.asarray(set1)
        set2 = np.asarray(set2)
        plt.figure()
        ot.plot.plot2D_samples_mat(set1, set2, match_mat, c=[.5, .5, 1])
        plt.plot(set1[:, 0], set1[:, 1], '+g', label='Source samples')
        plt.plot(set2[:, 0], set2[:, 1], 'xr', label='Target samples')
        plt.legend(loc=0)
        plt.title('OT matrix with samples')

    @staticmethod
    def show_img_pd_RD_(img_path, pd, coord):
        '''
        === FOR RESEARCH AND DEBUG PURPOSE
        Print `coord`, plot the persistence diagram `pd` in a new figure,
        then show the image stored at `img_path`.

        pd: array-like of shape (point_num, 2), one [x, y] point per row
            (presumably birth/death values -- confirm with the producer).
        Raises ValueError if pd does not have that shape.

        Example usage: Viewer.show_img_pd_RD_(os.path.join(##, ##), pd, coord)
        '''
        pd = np.asarray(pd)
        # Explicit check instead of `assert`, which is stripped under -O.
        if pd.ndim != 2 or pd.shape[1] != 2:
            raise ValueError(
                'pd must have shape (point_num, 2), got {}'.format(pd.shape))
        print(coord)
        plt.figure()
        Viewer.plot_dgm_([pd])
        plt.figure()
        Viewer.imshow_fp_(img_path)

    @staticmethod
    def rectify_coord(img, x_coord, y_coord):
        '''
        Drop coordinates that fall on black pixels of img.

        img: the original black-and-white image, indexed as img[y][x].
        x_coord, y_coord: parallel lists of integer coordinates.
        Returns (x_coord_r, y_coord_r): new lists with only the pairs whose
        pixel value is > 0, in their original order.
        '''
        kept = [(x, y) for x, y in zip(x_coord, y_coord) if img[y][x] > 0]
        x_coord_r = [x for x, _ in kept]
        y_coord_r = [y for _, y in kept]
        return x_coord_r, y_coord_r

    @staticmethod
    def extract_points_set(sets, x_coord, y_coord):
        '''
        Append the x and y values of every point in every set to the
        output lists x_coord and y_coord (mutated in place).

        sets should have form of sets_num * [point_num * [x, y]]
        '''
        for sequence in sets:
            Viewer.extract_points_sequence(sequence, x_coord, y_coord)

    @staticmethod
    def extract_points_sequence(sequence, x_coord, y_coord):
        '''
        Append point[0] to x_coord and point[1] to y_coord for each point
        in sequence (both output lists are mutated in place).

        sequence should have form of point_num * [x, y].
        '''
        for point in sequence:
            x_coord.append(point[0])
            y_coord.append(point[1])

    @staticmethod
    def draw_bnd_or_red_on_single_dim(img, t, idx=-1, rectify=False, display=True):
        '''
        Collect the point coordinates of one or all structures in t and
        optionally scatter them in red over img.

        t should be either bnd, red output from Utility_topo.compute_dist_homology,
        i.e. structure_num * [point_num * [x, y]].
        idx: index of the structure to show; -1 means all structures.
        rectify: if true, drop points lying on black pixels (see rectify_coord).
        display: if true, show img with the points overlaid.
        Returns (x_coord, y_coord) lists.
        '''
        x_coord = []
        y_coord = []
        if idx != -1:
            Viewer.extract_points_sequence(t[idx], x_coord, y_coord)
        else:
            Viewer.extract_points_set(t, x_coord, y_coord)
        if rectify:
            x_coord, y_coord = Viewer.rectify_coord(img, x_coord, y_coord)
        if display:
            plt.imshow(img, cmap='gray')
            plt.scatter(x_coord, y_coord, c='r', s=5)
            plt.show()
        return x_coord, y_coord

    @staticmethod
    def _gray_to_rgb(image):
        '''
        Return a new (H, W, 3) array with the 2-D grayscale image copied into
        each channel, keeping its dtype.

        Pure-numpy equivalent of cv2.cvtColor(image, cv2.COLOR_GRAY2RGB),
        used so these helpers do not depend on OpenCV (cv2 is not imported here).
        '''
        return np.repeat(np.asarray(image)[:, :, np.newaxis], 3, axis=2)

    @staticmethod
    def draw_coord_on_image(x, y, image, value, color=False):
        '''
        Set `value` at the given pixel coordinates and show the image in a
        new figure.

        x, y: parallel sequences of x (column) and y (row) coordinates;
            each entry is converted with int().
        image: 2-D grayscale or 3-D (H, W, C) numpy array.
        value: scalar for grayscale, or a per-channel tuple (a scalar is
            broadcast to all channels) for color images.
        color: if True and image is 2-D, draw on a new 3-channel RGB copy.
        Returns the drawn-on image. NOTE: when no RGB conversion happens,
        the caller's array is modified in place.
        '''
        plt.figure()
        if image.ndim == 2 and color:
            image = Viewer._gray_to_rgb(image)
        for col, row in zip(x, y):
            # For 3-D images this assigns across all channels of the pixel.
            image[int(row), int(col)] = value
        if image.ndim == 2:
            plt.imshow(image, cmap='gray')
        else:
            plt.imshow(image)
        return image

    @staticmethod
    def split_cv_bnd_into_x_y(bnd, index):
        '''
        Flatten selected OpenCV-style boundaries into x and y lists.

        bnd should have form: struct_num * [point_num * [1 * [x ,y]]].
        index is an iterable of integers with desired structure indices.
        Returns (x_coord, y_coord), concatenated in the order of `index`.
        Raises IndexError if an index is >= len(bnd).
        '''
        x_coord = []
        y_coord = []
        for dim in index:
            if dim >= len(bnd):
                raise IndexError(
                    'structure index {} out of range for {} structures'.format(
                        dim, len(bnd)))
            for point in bnd[dim]:
                # Each point is wrapped in an extra length-1 level: [[x, y]].
                x_coord.append(point[0][0])
                y_coord.append(point[0][1])
        return x_coord, y_coord

    @staticmethod
    def draw_bnd_on_single_dim_cv(image, bnd, value, idx=-1):
        '''
        Draw OpenCV-style boundaries onto an RGB version of image and show it.

        bnd should have form: struct_num * [point_num * [1 * [x ,y]]].
        value: the color to be drawn.
        idx: -1 for all structures, otherwise the index of one structure.
        Returns the drawn-on RGB image. A 2-D input is copied to RGB; a 3-D
        input is modified in place.
        '''
        if image.ndim == 2:
            image = Viewer._gray_to_rgb(image)
        indices = range(len(bnd)) if idx == -1 else [idx]
        x_coord, y_coord = Viewer.split_cv_bnd_into_x_y(bnd, indices)
        return Viewer.draw_coord_on_image(x_coord, y_coord, image, value, color=True)

    @staticmethod
    def draw_red_on_single_dim_cv(red, cmap='jet'):
        '''
        Show a label matrix in a new figure and return its label count.

        red: [0] is the number of labels (connected components), [1] is the
            label matrix (presumably the pair returned by
            cv2.connectedComponents -- confirm with the caller).
        cmap: matplotlib colormap used for the label matrix.
        '''
        plt.figure()
        plt.imshow(red[1], cmap=cmap)
        return red[0]