In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import time
import numpy as np
from six import iteritems
import shelve
from scipy.ndimage.measurements import label
import h5py

import PIL
import os
from IPython.display import Image, display
from IPython.html.widgets import interact, fixed
from IPython.html import widgets
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt

import os
from os.path import expanduser
import time

import torchvision
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, ConcatDataset
from torch.nn.parallel import data_parallel
from torchvision import transforms


import importlib
import augment
import loss
import pyramid
import masks


importlib.reload(augment)
importlib.reload(masks)
importlib.reload(loss)
importlib.reload(pyramid)
importlib.reload(masks)


from pyramid import Pyramid, create_pyramid
from loss import similarity_score, smoothness_penalty, get_dataset_loss, get_mse_and_smoothness_masks
from masks import get_defect_mask, get_white_mask
from augment import Translation, RandomTranslation, ToFloatTensor, BlackBox, GreyBox, SawtoothBoundary, RandomWarp, CropMiddle, Downsample, combine_residuals, res_warp_img, res_warp_res
from augment import set_up_transformation, generate_transform, apply_transform
from loss import similarity_score, smoothness_penalty, smoothness_score
from helpers import reverse_dim, downsample

def get_np(pt):
    """Convert tensor ``pt`` to a numpy array, detached from autograd and moved to CPU if needed."""
    return pt.detach().cpu().numpy()

def prepare_for_show(img):
    """Return ``img`` as a numpy array with all singleton axes squeezed away.

    Torch tensors are detached and copied to CPU first; anything else is
    assumed to already expose ``.squeeze()`` (e.g. a numpy array).
    """
    if not isinstance(img, torch.Tensor):
        return img.squeeze()
    return img.cpu().detach().numpy().squeeze()


def display_image(img, x_coords=None, y_coords=None, normalize=False, figsize=(10, 10), mask=False):
    """Show a window of a 2-D image (numpy array or torch tensor) in grayscale.

    Args:
        img: image to show; squeezed to 2-D via ``prepare_for_show``.
        x_coords: ``[start, stop)`` rows to show; defaults to all rows.
        y_coords: ``[start, stop)`` columns to show; defaults to all columns.
        normalize: if True, let matplotlib auto-scale intensities; otherwise
            use the fixed range [-0.5, 0.5].
        figsize: matplotlib figure size.
        mask: only validated against ``normalize``. Masks are drawn with the
            same fixed [-0.5, 0.5] range as other unnormalized images (the old
            dedicated [0, 1] mask branch was disabled with ``and False`` and
            has been removed).

    Raises:
        ValueError: if both ``normalize`` and ``mask`` are set.
    """
    if normalize and mask:
        raise ValueError("Masks can't be normalized")
    img = prepare_for_show(img)
    plt.figure(figsize=figsize)

    if x_coords is None:
        x_coords = [0, img.shape[0]]
    if y_coords is None:
        y_coords = [0, img.shape[1]]
    window = img[x_coords[0]:x_coords[1], y_coords[0]:y_coords[1]]

    if normalize:
        plt.imshow(window, cmap='gray')
    else:
        plt.imshow(window, cmap='gray', vmin=-0.5, vmax=0.5)


class SectionVisualizer(object):
    """Interactive viewer for sections stored under the 'main' key of an HDF5 file.

    Can also show masks read from a second HDF5 file and masks predicted by
    ``masker_net``. Masks and network inputs are built with
    ``torch.cuda.FloatTensor``, so those views need a CUDA device.
    """

    # Views offered by the widget, in display order.
    _CHOICES = ['img', 'net mask', 'tru', 'filtered mask']

    def __init__(self, data_file, masks_file=None, masker_net=None, pairs=False):
        """Open the image volume and, optionally, the mask volume.

        Args:
            data_file: HDF5 file whose 'main' dataset holds the images.
            masks_file: optional HDF5 file whose 'main' dataset holds masks
                indexed like ``data_file``; read fully into memory as float32.
            masker_net: optional callable applied to a (1, 1, H, W) CUDA tensor
                holding ``img / 255 - 0.5``; its output is the predicted mask.
            pairs: if True, section ``i`` is ``dataset[i, 0]``; otherwise it is
                ``dataset[..., i, :, :]`` squeezed.
        """
        self.dataset = h5py.File(data_file, 'r')['main']
        # With a leading batch axis of size > 1, keep only the first entry.
        if len(self.dataset.shape) > 3 and self.dataset.shape[0] != 1:
            self.dataset = self.dataset[0:1]
        if masks_file is not None:
            masks = h5py.File(masks_file, 'r')['main'][...]
            if len(masks.shape) > 3 and masks.shape[0] != 1:
                masks = masks[0:1]
            self.masks = masks.astype(np.float32)
            print (self.masks.shape)
            print (self.dataset.shape)
        else:
            self.masks = None

        self.masker_net = masker_net
        self.pairs = pairs
        # Id of the most recently loaded section (-1: nothing loaded yet).
        self.last_id = -1
        # (img_id, cropped) of the cached section. The crop size is part of the key
        # so that changing it re-reads the section instead of reusing a stale crop.
        self._loaded_key = None

    def _missing(self, choice):
        """Return why view ``choice`` cannot be shown with the configured inputs, or None if it can."""
        if choice in ('net mask', 'filtered mask', 'loss') and self.masker_net is None:
            return "'%s' needs a masker_net" % choice
        if choice in ('tru', 'loss') and self.masks is None:
            return "'%s' needs a masks_file" % choice
        return None

    def _load(self, img_id, cropped):
        """Read section ``img_id`` (center-cropped to ``cropped`` px if non-zero) and run the net on it."""
        if self.pairs:
            img = self.dataset[img_id, 0]
        else:
            img = self.dataset[..., img_id, :, :].squeeze()

        true_mask = None
        if self.masks is not None:
            if self.pairs:
                true_mask = torch.cuda.FloatTensor(self.masks[img_id, 0])
            else:
                true_mask = torch.cuda.FloatTensor(self.masks[..., img_id, :, :]).squeeze()

        if cropped != 0:
            # Square crop centered on the image; assumes a square section.
            start = (img.shape[0] - cropped) // 2
            end = start + cropped
            img = img[start:end, start:end]
            if true_mask is not None:
                true_mask = true_mask[..., start:end, start:end]

        self.img = img
        self.true_mask = true_mask
        if self.masker_net is not None:
            self.img_var = torch.cuda.FloatTensor(self.img) / 255. - 0.5
            self.pred_mask = self.masker_net(self.img_var.unsqueeze(0).unsqueeze(0))
            self.pred_mask_np = get_np(self.pred_mask).squeeze() > 0.9
            # Drop predicted blobs smaller than 30 pixels.
            self.filtered_np = filter_connected_component(self.pred_mask_np, 30)

        # Only mark the section as cached once loading fully succeeded.
        self.last_id = img_id
        self._loaded_key = (img_id, cropped)

    def loadimg(self, img_id, choice, cropped=0):
        """Show view ``choice`` of section ``img_id`` (used as the ``interact`` callback).

        Args:
            img_id: section index.
            choice: 'img', 'net mask', 'tru', 'filtered mask', or 'loss'.
            cropped: side of the centered square crop in pixels; 0 disables cropping.
        """
        if (img_id, cropped) != self._loaded_key:
            self._load(img_id, cropped)

        reason = self._missing(choice)
        if reason is not None:
            print (reason)
            return

        if choice == 'img':
            display_image(self.img, normalize=True)
        elif choice == 'tru':
            display_image(self.true_mask, normalize=True)
        elif choice == 'filtered mask':
            display_image(self.filtered_np, normalize=False)
        elif choice == 'net mask':
            display_image(self.pred_mask, normalize=True)
        elif choice == 'loss':
            print ("image: ", torch.mean(torch.abs(self.img_var)))
            print ("pred mask: ", torch.mean(self.pred_mask))
            print ("print true mask: ", torch.mean(self.true_mask))
            # NOTE(review): assumes the project `loss` module defines mask_diff_loss — not visible here.
            l = loss.mask_diff_loss()(self.pred_mask, self.true_mask)
            print ("loss: ", l)
            false_negative = (self.true_mask - self.pred_mask) > 0
            display_image(false_negative, normalize=False)

    def visualize(self, section_count=1, state_file=None, cropped_size=0):
        """Display the section-id box and view buttons.

        ``section_count`` and ``state_file`` are accepted for compatibility but unused.
        ``cropped_size`` is passed to ``interact`` as the ``cropped`` argument.
        """
        self.id_selector = widgets.IntText(
            value=0,
            description='Sample ID:',
            disabled=False
        )

        # Offer only the views the configured inputs can actually produce.
        buttons = [name for name in self._CHOICES if self._missing(name) is None]
        self.button_choice_1 = widgets.ToggleButtons(
            options=buttons,
            description='Image:',
            disabled=False,
            button_style='',
        )
        interact(self.loadimg, img_id=self.id_selector, choice=self.button_choice_1, cropped=cropped_size)


            
class simple_visualizer():
    """Toggle between several 2-D images (numpy arrays or torch tensors) with buttons."""

    def __init__(self):
        pass

    def _image_for_display(self, index):
        """Return image ``index`` as a numpy array with singleton axes squeezed away.

        Tensors are converted once and cached in ``self.images``, which is a
        private copy of the caller's list, so the caller's objects are untouched.
        """
        img = self.images[index]
        if isinstance(img, torch.Tensor):
            img = img.cpu().detach().numpy()
            self.images[index] = img
        # squeeze() returns a new view; the original discarded it.
        return img.squeeze()

    def display_multiimg(self, choice):
        """Draw the image named ``choice``; when ``self.crop`` > 0 that many pixels are trimmed from each edge."""
        img = self._image_for_display(self.names.index(choice))

        plt.figure(figsize = (12,12))
        if self.crop == 0:
            plt.imshow(img, cmap='gray', vmin=-0.5, vmax=0.5)
        else:
            # The cropped view is auto-scaled (no vmin/vmax), unlike the full view.
            plt.imshow(img[self.crop:-self.crop, self.crop:-self.crop], cmap='gray')

    def visualize(self, images, names=None, crop=0):
        """Show toggle buttons for ``images``.

        Args:
            images: sequence of images (numpy arrays or torch tensors).
            names: button labels, one per image; defaults to their indices.
            crop: pixels trimmed from each edge before drawing (0 = no crop).
        """
        if names is None:
            names = range(len(images))
        self.names = names
        # Copy so the tensor -> numpy caching does not mutate the caller's list.
        self.images = list(images)
        self.crop = crop

        button_choice = widgets.ToggleButtons(
                options=names,
                description='Image:',
                disabled=False,
                button_style='',
            )

        interact(self.display_multiimg, choice=button_choice)

from scipy.ndimage.measurements import label

def filter_connected_component(array, N):
    """Remove connected foreground components smaller than ``N`` pixels.

    Components of the non-zero pixels of the 2-D ``array`` are found with
    8-connectivity (3x3 structuring element). Every component with fewer than
    ``N`` pixels is set to False in a copy; ``array`` itself is not modified.

    Args:
        array: 2-D array (boolean in this notebook's callers) whose non-zero
            pixels are foreground. Need not be square.
        N: minimum component size, in pixels, that is kept.

    Returns:
        A copy of ``array`` with the small components cleared.
    """
    structure = np.ones((3, 3), dtype=int)
    labeled, ncomponents = label(array, structure)
    result = np.copy(array)
    if ncomponents == 0:
        return result
    # Component labels run 1..ncomponents; label 0 is the background and is never cleared.
    sizes = np.bincount(labeled.ravel(), minlength=ncomponents + 1)
    too_small = sizes < N
    too_small[0] = False
    result[too_small[labeled]] = False
    return result



ImportError: No module named 'pyramid'

# Visualize

In [None]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="6"

m = None

#viz1 = SectionVisualizer("/usr/people/popovych/aligner/ncc_m7_x1.h5"
#                         , "/usr/people/popovych/aligner/ncc_m7_x1.h5", None)
#viz1 = SectionVisualizer("/usr/people/popovych/aligner/ncc_fromnormed_m8_z22500_x1.h5"
#                         , "/usr/people/popovych/aligner/ncc_fromnormed_m8_z22500_x1.h5", None)
viz1 = SectionVisualizer(
                        #"/usr/people/popovych/natmet_data/proofreading/cutout_x0_omni/data/img.h5"
                        #"/usr/people/popovych/natmet_data/proofreading/minnie_p3_x206000_y185000_z17550_17650_img.h5"
                        #, "/usr/people/popovych/natmet_data/proofreading/minnie_p3_x206000_y185000_z17550_17650_aff.h5",
                        "/usr/people/popovych/seungmount/Omni/TracerTasks/minnie/Minnie_GT_ForAlignmentQualityAssessment/vol3/cutout_x3_omni/data/img.h5"
                        , "/usr/people/popovych/seungmount/Omni/TracerTasks/minnie/Minnie_GT_ForAlignmentQualityAssessment/vol3/cutout_x3_omni/data/seg.h5",
                         #, "/usr/people/popovych/natmet_data/proofreading/cutout_x1_omni/data/seg.h5",
                         
                         None, pairs=False)


In [2]:
from cloudvolume import CloudVolume as cv

In [6]:
# MIP level to read. The CloudVolume is opened at this same level so the two cannot drift apart.
mip = 4
c = cv("precomputed://gs://seunglab_minnie_phase3/sergiy/final_fine_x4/image_stitch_dd200_mip1_final", mip=mip)

In [8]:

# Bounding box in mip-0 voxel coordinates (x, y, z), one section thick.
m0_start = (131315, 140431, 20563)
# 512 px at `mip` span 512 * 2**mip px at mip 0 (mip**2 only coincided with 2**mip for mip 2 and 4).
m0_size = 512 * 2**mip
m0_end = [m0_start[0] + m0_size, m0_start[1] + m0_size, m0_start[2] + 1]

# Convert x/y to `mip` coordinates. z is left unscaled: the original divided only the
# z start, which turned the 1-section range into ~19k sections.
# NOTE(review): assumes this volume's MIP levels downsample x/y only — confirm.
scale = 2**mip
c[m0_start[0]//scale:m0_end[0]//scale,
  m0_start[1]//scale:m0_end[1]//scale,
  m0_start[2]:m0_end[2]]

OutOfBoundsError: Value 131315 cannot be outside of inclusive range -320 to 31200

In [None]:
# Render the interactive viewer. The stray `defect_h5.close()` that used to follow has
# been removed: `defect_h5` is only created further down (and closed by those cells),
# so on a fresh kernel it raised NameError.
viz1.visualize()

In [8]:
data_h5 = h5py.File("minnie_v1_full_mip8_1536px_train.h5", 'r')
data_dset = data_h5['main']
dataset_dim = list(data_dset.shape)
print (dataset_dim)
fold_h5 = h5py.File("defect_net1_minnie_v1_full_mip8_1536px_train.h5", 'r')
fold_dset = fold_h5['main']
defect_h5 = h5py.File("defect_net_with_cracks2_minnie_v1_full_mip8_1536px_train.h5", 'w')
if 'main' in defect_h5:
    defect_dset = defect_h5['main']
else:
    defect_dset = defect_h5.create_dataset("main", dataset_dim)

for i in range(0, dataset_dim[1]):
    #print (i)
    #sample = np.flip(np.rot90(defect_gt_dset[i], k=1), axis=0)
    img = data_dset[0, i]
    fold_np = fold_dset[0, i]
    img_var = torch.cuda.FloatTensor(img) / 255. - 0.5
    crack_mask = m(img_var.unsqueeze(0).unsqueeze(0))
    crack_mask_np = get_np(crack_mask).squeeze() > 0.9
    filtered_crack_np = filter_connected_component(crack_mask_np, 30)
    defect_dset[0, i] = np.logical_or(filtered_crack_np, fold_np)
    print (i)

defect_h5.close()

[1, 400, 1536, 1536]
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
2

In [7]:
defect_gt_h5 = h5py.File("minnie_v1_full_mip8_1536px_train_edges.h5", 'r')
defect_gt_dset = defect_gt_h5['main']
dataset_dim = [1] + list(defect_gt_dset.shape)
print (dataset_dim)
defect_h5 = h5py.File("plastic_mask2_minnie_v1_full_mip8_1536px_train.h5", 'w')
if 'main' in defect_h5:
    defect_dset = defect_h5['main']
else:
    defect_dset = defect_h5.create_dataset("main", dataset_dim)

for i in range(0, dataset_dim[1]):
    print (i)
    sample = np.flip(np.rot90(defect_gt_dset[i], k=1), axis=0)
    defect_dset[0, i] = sample

defect_h5.close()

[1, 400, 1536, 1536]
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
2

In [25]:
# Release the HDF5 handles left open by the edge-mask export cell.
for h5_file in (defect_h5, defect_gt_h5):
    h5_file.close()

In [5]:
# NOTE(review): `range(0, 0)` is empty, so this loop never executes. `vec_dset` is not
# defined anywhere in this notebook, and SectionVisualizer has no `get_residuals` method,
# so re-enabling the loop as written would fail.
for i in range(0, 0):
    print (i)
    vec_dset[i] = viz1.get_residuals(img_id=i, val=False, sup=False)[0].cpu().detach().numpy()

In [6]:
# NOTE(review): `vec_h5` is never opened in this notebook, so this raises NameError on a fresh kernel.
vec_h5.close()

NameError: name 'vec_h5' is not defined

In [None]:
# NOTE(review): `a` is not defined anywhere in this notebook — likely left over from a deleted cell.
print (a.shape)

In [None]:
# Build the MIP-8 dataset from the image volume plus edge/defect/plastic masks, then
# pass it to `set_up_transformation`.
# NOTE(review): `compile_dataset` is not imported or defined in this notebook — confirm
# which project module provides it; the meaning of `supervised`/`max_size` is not visible here.
dataset_mip = 8
dataset_m8 = compile_dataset([{"data": {"h5": "./minnie_full_mip8_1300px_train.h5", "mip": 8}, 
                                        "masks": {
                                            "edges": {"h5": "./edge_mask_minnie_full_mip8_1300px_train.h5", "mip": 8},
                                            "defects": {"h5": "./defect_mask_greedy93_minnie_full_mip8_1300px_train.h5", "mip": 8},
                                            "plastic": {"h5": "plastic_mask_minnie_full_mip8_1300px_train.h5", "mip": 8}
                                        }}], supervised=False, max_size=500)
set_up_transformation({}, dataset_m8, 0, 0)

In [None]:
# Transform definition. The alternative {8: [{"type": "preprocess"}]} was immediately
# overwritten in the original cell, so only the empty definition is in effect.
tdef = {}

# Crop window at MIP 8, and the corresponding window at MIP 6 (factor 2**(8 - 6) = 4).
# NOTE(review): the -2 offset on the MIP-6 bounds is unexplained — confirm intent.
MIP8_TO_MIP6 = 2 ** (8 - 6)
s8, e8 = 200, 240
s6 = s8 * MIP8_TO_MIP6 - 2
e6 = e8 * MIP8_TO_MIP6 - 2

In [None]:
# Crop sample 9 to the MIP-6 window, rescale with /255 - 0.5 (assumes 0-255 intensities —
# confirm), and apply a transform generated for MIP 6 -> 8.
# NOTE(review): SectionVisualizer never sets `unsup_train_dataset`, and only `apply_transform`
# (not `apply_transform_to_sample`) is imported at the top, so this cell fails as written.
sample_m6 = viz1.unsup_train_dataset[9].unsqueeze(0)[:, :, s6:e6, s6:e6] / 255.0 - 0.5
m6_trans = generate_transform(tdef, 6, 8)
processed_m6 = apply_transform_to_sample(sample_m6, m6_trans).unsqueeze(0)

In [None]:
# Same as the MIP-6 cell, but for sample 9 of `dataset_m8` at MIP 8 with a transform generated for MIP 8 -> 8.
# NOTE(review): `apply_transform_to_sample` is not imported at the top (only `apply_transform` is).
sample_m8 = dataset_m8[9].unsqueeze(0)[:, :, s8:e8, s8:e8] / 255.0 - 0.5
m8_trans = generate_transform(tdef, 8, 8)
processed_m8 = apply_transform_to_sample(sample_m8, m8_trans).unsqueeze(0)

In [None]:
import scipy.ndimage

# Compare the MIP-6 crop with the other crops, upsampled 4x with nearest-neighbour
# (order=0) so every image is shown at the MIP-6 pixel grid.
src_m6 = sample_m6[0, 0]
src_p_m6 = scipy.ndimage.zoom(processed_m6[0, 0], 4, order=0)

src_m8 = scipy.ndimage.zoom(sample_m8[0, 0], 4, order=0)
src_p_m8 = scipy.ndimage.zoom(processed_m8[0, 0], 4, order=0)
# src_p_m8 was computed but never shown before; include it and label the buttons.
simple_visualizer().visualize(
    [src_m6, src_p_m6, src_m8, src_p_m8],
    names=['m6', 'm6 processed (4x)', 'm8 (4x)', 'm8 processed (4x)'],
)

In [None]:
# Length of the first axis of viz1.dataset[0, 0] (for a 4-D volume: the section height).
len(viz1.dataset[0, 0])