In [None]:
import glob
import os, sys
import numpy
from scipy import misc
import cv2

import matplotlib.pyplot as plt
import matplotlib

% matplotlib inline
sys.path.append('../utilities')
sys.path.append('./ori_tf_version/')
import pickling
import loading
from data import parts_combination

from ori_tf_version import data

In [None]:
# Locate the DAGAN visual outputs of one experiment and index them by split and epoch.
# NOTE(review): absolute, machine-specific path — consider moving it to a config cell.
directory = '/data_set/Nan/saves/DAGAN_results/'
model = 'patch/group16' + '/experiment/visual_outputs/*' #size64/oridagan #group16
folder = os.path.join(directory, model)
# glob returns files in arbitrary order; sort so re-runs iterate identically.
images = sorted(glob.glob(folder))


def parse_visual_output_name(path):
    """Return ``(data_type, epoch)`` parsed from a visual-output file path.

    The file name is split on ``_``: field 0 is the data split (compared
    against ``'train'``/``'val'``/``'test'`` below) and field 6 holds the epoch
    number, optionally followed by a file extension.
    """
    fields = os.path.basename(path).split('_')
    return fields[0], int(fields[6].split('.')[0])


parsed_names = [parse_visual_output_name(x) for x in images]
epoch_per_image = [epoch for _, epoch in parsed_names]
data_type_list = [data_type for data_type, _ in parsed_names]
if epoch_per_image:
    print('max epoch:', max(epoch_per_image))
else:
    # max() of an empty list raises ValueError; report the unmatched pattern instead.
    print('no visual outputs matched:', folder)
images_list = list(zip(epoch_per_image, images, data_type_list))

images_list_train = [x for x in images_list if x[2] == 'train']
images_list_val = [x for x in images_list if x[2] == 'val']
images_list_test = [x for x in images_list if x[2] == 'test']

In [None]:
def load_interpolations(image_name, im_visual = False, preview_width = 256 * 8):
    """Load a saved grid image of generated samples as a float32 array.

    Parameters
    ----------
    image_name : str
        Path of the image file to read.
    im_visual : bool
        If True, display the leftmost ``preview_width`` columns of the image
        in grayscale.
    preview_width : int
        Number of columns shown in the preview (default ``256 * 8``).

    Returns
    -------
    numpy.ndarray
        The image converted to ``float32``.
    """
    # Read the file named by the argument (not a same-named global variable).
    # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this call needs
    # an older SciPy or a replacement reader (e.g. cv2.imread).
    img = misc.imread(image_name)
    img = img.astype(numpy.float32)
    if im_visual:
        plt.figure( figsize = (40, 40))
        plt.imshow(img[:, :preview_width], cmap = 'gray')
        plt.axis('off')
        plt.show()
    return img

def noisy(noise_typ, image, var = 1):
    """Return a 2-D ``image`` with synthetic noise of the requested kind.

    Parameters
    ----------
    noise_typ : str
        One of ``'gauss'``, ``'s&p'``, ``'poisson'``, ``'speckle'`` or ``'none'``.
    image : numpy.ndarray
        2-D array of shape ``(row, col)``.
    var : float
        Variance of the additive Gaussian noise (used by ``'gauss'`` only).

    Returns
    -------
    numpy.ndarray
        A new noisy array; for ``'none'`` the input object itself is returned.

    Raises
    ------
    ValueError
        If ``noise_typ`` is not one of the supported names.
    """
    row, col = image.shape

    if noise_typ == "gauss":
        sigma = var ** 0.5
        return image + numpy.random.normal(0, sigma, (row, col))

    elif noise_typ == "s&p":
        s_vs_p = 0.5
        amount = 0.004
        out = numpy.copy(image)
        # Salt mode: set randomly chosen pixels to 1.
        num_salt = int(numpy.ceil(amount * image.size * s_vs_p))
        # randint's upper bound is exclusive, so pass the dimension size itself
        # to allow the last row/column to be chosen. A tuple of index arrays is
        # required for per-axis (coordinate) indexing.
        coords = tuple(numpy.random.randint(0, dim, num_salt) for dim in image.shape)
        out[coords] = 1

        # Pepper mode: set randomly chosen pixels to 0.
        num_pepper = int(numpy.ceil(amount * image.size * (1. - s_vs_p)))
        coords = tuple(numpy.random.randint(0, dim, num_pepper) for dim in image.shape)
        out[coords] = 0
        return out

    elif noise_typ == "poisson":
        # NOTE(review): numpy.random.poisson requires image * vals >= 0.
        vals = len(numpy.unique(image))
        vals = 2 ** numpy.ceil(numpy.log2(vals))
        return numpy.random.poisson(image * vals) / float(vals)

    elif noise_typ == "speckle":
        # Multiplicative noise with unit variance (``var`` is not used here).
        gauss = numpy.random.randn(row, col)
        return image + image * gauss

    elif noise_typ == 'none':
        return image

    raise ValueError("unknown noise_typ: %r" % (noise_typ,))
    
class noise_functions():
    """Callable wrapper that applies an OpenCV edge filter (or nothing) to an image.

    Parameters
    ----------
    noise_type : str
        ``'Laplacian'``, ``'Sobel'`` or ``'None'`` (identity). Any other value
        raises ``KeyError``.
    ksize : int
        Kernel size passed to ``cv2.Sobel``.
    order : sequence of two ints
        Derivative orders ``(dx, dy)`` passed to ``cv2.Sobel``.
    """
    def __init__(self, noise_type, ksize = 5, order = (0, 1)):
        func_dict = {'Laplacian': self.Laplacian, 'Sobel': self.Sobel, 'None': self.Ori}
        self.noise_type = noise_type
        self.ksize = ksize
        self.order = order
        self.func = func_dict[noise_type]

    def Laplacian(self, image):
        """Return the Laplacian of ``image`` with output depth ``cv2.CV_64F``."""
        return cv2.Laplacian(image, cv2.CV_64F)

    def Sobel(self, image):
        """Return the Sobel derivative of ``image`` with orders ``self.order``."""
        # ksize must be passed by keyword: the 5th positional parameter of
        # cv2.Sobel is ``dst``, not ``ksize``.
        return cv2.Sobel(image, cv2.CV_64F, self.order[0], self.order[1], ksize=self.ksize)

    def Ori(self, image):
        """Return ``image`` unchanged."""
        return image

    def __call__(self, image):
        """Apply the filter selected by ``noise_type``."""
        return self.func(image)
        
    
def visualize_single(img, i, j, size = 256, noise_type = 'none', normalize = True, visual = True, var = 1):
    """Crop tile ``(i, j)`` of a ``size``-pixel grid from ``img`` and pass it through ``noisy``.

    The crop is copied, so ``img`` itself is never modified. When ``normalize``
    is set, ``loading.normalise_single_image`` is called on the crop; its return
    value is ignored, so presumably it normalises in place (TODO confirm). The
    crop is then replaced by ``noisy(noise_type, crop, var)``. When ``visual``
    is true, a new 4x4-inch grayscale figure of the result is drawn.

    Returns the resulting tile.
    """
    top, left = i * size, j * size
    tile = img[top:top + size, left:left + size].copy()

    if normalize:
        loading.normalise_single_image(tile)

    tile = noisy(noise_type, tile, var)

    if not visual:
        return tile

    plt.figure(figsize = (4, 4))
    plt.imshow(tile, cmap='gray')
    plt.axis('off')
    return tile

def compare_with_ori_image(img, i, j, size, noise_type, normalize, i_ori = 0, show = False):
    """Crop a reference tile and a comparison tile from the same grid image.

    Parameters
    ----------
    img : numpy.ndarray
        Grid image made of ``size``-pixel tiles.
    i, j : int
        Row and column of the tile to compare.
    size : int
        Tile edge length in pixels.
    noise_type : str
        Noise applied (via ``visualize_single``) to the compared tile; its
        variance is set to the variance of the reference tile.
    normalize : bool
        Forwarded to ``visualize_single`` for both tiles.
    i_ori : int
        Row of the reference tile; the reference is always taken from column 0.
    show : bool
        If True, also display the reference (left) and compared (right) tiles
        side by side.

    Returns
    -------
    tuple of numpy.ndarray
        ``(ori_image, image)``.
    """
    ori_image = visualize_single(img, i_ori, 0, size = size, noise_type = 'none', normalize = normalize, visual = False)
    image = visualize_single(img, i, j, size = size, noise_type = noise_type, var = numpy.var(ori_image), normalize = normalize, visual = False)

    if show:
        fig = plt.figure(figsize=(4, 4))
        for position, tile in enumerate((ori_image, image), start=1):
            ax = fig.add_subplot(1, 2, position)
            ax.imshow(tile, cmap='gray')
            ax.axis('off')
        plt.show()

    return ori_image, image
    

In [None]:
# Load (and preview) the first three test-split grids saved at epoch 58.
img_list = []
for epoch, image, _ in images_list_test:
    if epoch != 58:
        continue
    print(image)
    img_list.append(load_interpolations(image, True))
    if len(img_list) >= 3:
        break

In [None]:
# Shapes of one (reference, compared) tile pair taken from the first loaded grid.
# The pair is computed here explicitly so this cell does not rely on variables
# that only a later cell defines.
sample_ori, sample_image = compare_with_ori_image(img_list[0], 0, 7, 256, 'none', True)
sample_ori.shape, sample_image.shape

In [None]:
# For every tile row of the first grid, show the reference tile (column 0)
# next to the tile in column COMPARE_COLUMN of the same row.
TILE_SIZE = 256
COMPARE_COLUMN = 7
n_rows = img_list[0].shape[0] // TILE_SIZE

fig = plt.figure(figsize=(4, 34))
for test_img_index in range(n_rows):
    ori_image, image = compare_with_ori_image(img_list[0], test_img_index, COMPARE_COLUMN, TILE_SIZE, 'none', True, i_ori=test_img_index)
    # Reference on the left, compared tile on the right.
    image_to_shape = numpy.hstack([ori_image, image])

    # One subplot row per tile row, so the grid always matches the loop length.
    ax = fig.add_subplot(n_rows, 1, test_img_index + 1)
    ax.imshow(image_to_shape, cmap='gray')
    ax.axis('off')

plt.show()
    

In [None]:
# Directory of the pre-split segmentation patches and how many parts of each
# split parts_combination should combine.
# NOTE(review): absolute, machine-specific path — consider moving to a config cell.
data_parameters = {'patches_directory': '/data_set/Nan/data/segmentation_patches/DAGAN_binary_parts_size_2000/',
                   'training_data_parts': 2,
                   'validation_data_parts': 1,
                   'test_data_parts': 1}

x_train = parts_combination(data_parameters['patches_directory'], 'training', data_parameters['training_data_parts'])
x_val = parts_combination(data_parameters['patches_directory'], 'validation', data_parameters['validation_data_parts'])
x_test = parts_combination(data_parameters['patches_directory'], 'test', data_parameters['test_data_parts'])

In [None]:
# Shape of element 1 of the training data returned by parts_combination.
# NOTE(review): presumably the masks/labels paired with x_train[0] — TODO confirm.
x_train[1].shape

In [None]:
# Keep every 4th pixel along axes 1 and 2 of x_train[1] (equivalent to taking
# every 2nd pixel twice) and copy, so `b` does not share memory with x_train.
b = x_train[1][:, ::4, ::4, :].copy()

In [None]:
# Preview up to N_PREVIEW of the downsampled arrays in `b` (channel 0 only).
N_PREVIEW = 100
max_num_of_image = b.shape[0]
# Bound the loop by the number of available images to avoid an IndexError.
for i in range(min(N_PREVIEW, max_num_of_image)):
    plt.imshow(b[i, :, :, 0], cmap='gray')
    plt.axis('off')
    plt.show()

In [None]:
# Preview up to N_PREVIEW of the arrays in x_train[0] (channel 0 only).
N_PREVIEW = 100
# Count the array that is actually indexed below (x_train[0], not x_train[1]).
max_num_of_image = x_train[0].shape[0]
for i in range(min(N_PREVIEW, max_num_of_image)):
    plt.imshow(x_train[0][i, :, :, 0], cmap='gray')
    plt.axis('off')
    plt.show()