In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import scipy.io as sio
%matplotlib inline
from PIL import Image
import os
import sys
import cv2
import time

## Make sure that caffe is on the python path:
caffe_root = '../../'
sys.path.insert(0, caffe_root + 'python')
import caffe

In [2]:
data_root = '../../data/BSDS500/'
## Each non-empty line of the list file holds two whitespace-separated paths,
## both relative to data_root: the input image and the matching label map
## (presumably a superpixel map, given the 'sp' naming; it is fed to 'sp_label').
with open(data_root + 'test_gpu.lst') as f:
    # Skip blank lines (e.g. an extra empty line at the end of the file); they
    # would otherwise make split()[0] / split()[1] raise IndexError below.
    test_lst = [line for line in f if line.strip()]

img_lst = [x.strip().split()[0] for x in test_lst]
sp_lst = [x.strip().split()[1] for x in test_lst]

In [3]:
## Principal Component Analysis (PCA)
def pca(data, n):
    """Project `data` onto its `n` leading principal components.

    Parameters
    ----------
    data : array_like, shape (n_samples, n_features)
        Observations in rows, features in columns.
    n : int
        Number of components to keep, largest-variance first. Values larger
        than n_features are clipped by slicing; 0 gives an empty projection.

    Returns
    -------
    numpy.matrix, shape (n_samples, min(n, n_features))
        The mean-centred data expressed in the principal-component basis.
        Each component's sign is arbitrary.
    """
    data = np.asarray(data)
    centred = data - np.mean(data, axis=0)
    # np.cov returns a 0-d array for single-feature input; keep it 2-D so the
    # eigen-solver always receives a square matrix.
    cov = np.atleast_2d(np.cov(centred, rowvar=False))
    # A covariance matrix is symmetric, so use eigh: it always yields real
    # eigenvalues/eigenvectors, whereas eig can return complex values with
    # spurious imaginary parts caused by round-off.
    eig_vals, eig_vecs = np.linalg.eigh(cov)
    # Indices of the n largest eigenvalues, largest first (same selection as
    # slicing the ascending argsort from the end).
    top = np.argsort(eig_vals)[::-1][:n]
    # np.mat was removed in NumPy 2.0; np.asmatrix preserves the original
    # np.matrix return type for existing callers.
    return np.asmatrix(np.dot(centred, eig_vecs[:, top]))

In [4]:
## Visualization
def plot_single_scale(scale_lst, size):
    """Show the images in `scale_lst` side by side in a single row, without ticks.

    Parameters
    ----------
    scale_lst : sequence
        Images, each passed straight to `plt.imshow`.
    size : number
        Figure width; the height is `size/2`. This is written into the global
        `pylab.rcParams['figure.figsize']`, so it also affects later figures.
    """
    pylab.rcParams['figure.figsize'] = size, size/2
    plt.figure()
    # Keep the historical two-panel grid for one or two images, but widen it
    # for longer lists instead of failing with an out-of-range subplot index.
    n_cols = max(2, len(scale_lst))
    for i, img in enumerate(scale_lst):
        ax = plt.subplot(1, n_cols, i + 1)
        plt.imshow(img)
        # Hide tick labels and tick marks on both axes.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.yaxis.set_ticks_position('none')
        ax.xaxis.set_ticks_position('none')
    plt.tight_layout()

In [5]:
## Fixed 20-entry palette, one [R, G, B] row per entry. NOTE(review): the values
## appear to match the PASCAL VOC colour map for classes 1-20; confirm usage.
color_table = np.array([
    [128,   0,   0],
    [  0, 128,   0],
    [128, 128,   0],
    [  0,   0, 128],
    [128,   0, 128],
    [  0, 128, 128],
    [128, 128, 128],
    [ 64,   0,   0],
    [192,   0,   0],
    [ 64, 128,   0],
    [192, 128,   0],
    [ 64,   0, 128],
    [192,   0, 128],
    [ 64, 128, 128],
    [192, 128, 128],
    [  0,  64,   0],
    [128,  64,   0],
    [  0, 192,   0],
    [128, 192,   0],
    [  0,  64, 128],
])

In [6]:
## Device selection: set USE_GPU = False to test on the CPU (instead of deleting
## lines); GPU_ID selects the device index passed to caffe.
USE_GPU = True
GPU_ID = 1
if USE_GPU:
    caffe.set_mode_gpu()
    caffe.set_device(GPU_ID)
else:
    caffe.set_mode_cpu()

## load net
net = caffe.Net('test.prototxt', 'snapshots/image_segment_net_iter_10000.caffemodel', caffe.TEST)

## Output directory. The default is the original author's absolute path; set the
## DEEPSEG_SAVE_ROOT environment variable to write elsewhere
## (e.g. '../../data/BSDS500/DeepSeg2/').
save_root = os.environ.get('DEEPSEG_SAVE_ROOT',
                           '/home/liuyun/LiuYun/Code/DEL/data/seism/datasets/BSDS500/DeepSeg3/')
# makedirs also creates missing parent directories, where os.mkdir would fail.
if not os.path.isdir(save_root):
    os.makedirs(save_root)

In [7]:
## Sweep over (threshold, min_size) settings, segment every listed test image with
## the network and save each result as a .mat file under save_root.
start_time = time.time()
bound = range(350, 351)   # thresholds in hundredths: bound[i] / 100.0 is fed to the net
n_processed = 0           # forward passes actually run; used for the timing report

for i in range(len(bound)):
    threshold = bound[i] / 100.0
    ## A threshold-dependent min_size schedule ([0..3] for threshold <= 0.3,
    ## [0..10] otherwise) used to be computed here but was always overridden;
    ## only min_size 3 is evaluated.
    min_size = [3]
    for j in range(len(min_size)):
        new_save_root = os.path.join(save_root, 'threshold_{:.2f}_{}'.format(threshold, min_size[j]))
        if not os.path.isdir(new_save_root):
            os.makedirs(new_save_root)

        for idx in range(len(test_lst)):
            ## load and prepare an image: force 3 channels, reverse channel order
            ## (RGB -> BGR), subtract per-channel means, then HWC -> CHW.
            ## NOTE(review): the means are presumably the usual Caffe BGR means.
            im = Image.open(data_root + img_lst[idx]).convert('RGB')
            in_ = np.array(im, dtype=np.float32)
            in_ = in_[:, :, ::-1]
            in_ -= np.array((104.00698793, 116.66876762, 122.67891434))
            in_ = in_.transpose((2, 0, 1))
            ## load the per-image label map fed to the 'sp_label' blob
            sp_path = data_root + sp_lst[idx]
            sp_ = cv2.imread(sp_path, cv2.IMREAD_ANYDEPTH)
            if sp_ is None:
                # cv2.imread reports a missing/unreadable file by returning None.
                raise IOError('could not read label map: ' + sp_path)
            sp_ = sp_.astype(dtype=np.float32)
            ## shape for input (data blob is N x C x H x W), set data
            net.blobs['data'].reshape(1, *in_.shape)
            net.blobs['data'].data[...] = in_
            net.blobs['sp_label'].reshape(1, 1, *sp_.shape)
            net.blobs['sp_label'].data[...] = sp_
            net.blobs['bound_param'].reshape(1, 1, 1, 1)
            net.blobs['bound_param'].data[...] = threshold
            net.blobs['minsize_param'].reshape(1, 1, 1, 1)
            net.blobs['minsize_param'].data[...] = min_size[j]
            net.forward()
            ## read back the segmentation, move channels last, store as uint16
            out3 = net.blobs['segmentation'].data[0].copy()
            out3 = out3.transpose((1, 2, 0)).astype(dtype=np.uint16)
            # NOTE(review): [11:-4] drops a fixed 11-character prefix and a
            # 4-character extension from the list entry; confirm against test_gpu.lst.
            out_name = img_lst[idx][11:-4] + '.mat'
            sio.savemat(os.path.join(new_save_root, out_name), {'Segmentation': out3})
            n_processed += 1

diff_time = time.time() - start_time
## Average over the images actually processed: with several (threshold, min_size)
## settings, dividing by len(test_lst) overstated the per-image time.
print('Detection took {:.3f}s per image'.format(diff_time / max(n_processed, 1)))

Detection took 15.128s per image
