In [3]:
# Setup -- from CS 231N GANs notebook
from __future__ import print_function, division
import tensorflow as tf
import numpy as np

import os
import nibabel as nib
from nibabel.testing import data_path


import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline
# Notebook-wide plotting defaults: 10x8 inch figures, no smoothing between pixels, greyscale images.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})



In [14]:
# Utility functions: image plotting, [-1, 1] rescaling, error metrics and TensorFlow helpers

def show_images(images):
    """Plot a batch of square images on a roughly square grid of axes.

    images: array whose first axis indexes images. Each image is flattened and then
        reshaped to (side, side) with side = ceil(sqrt(pixels per image)), so the
        per-image pixel count must be a perfect square or the reshape fails.
    """
    flat = np.reshape(images, [images.shape[0], -1])  # (batch_size, D)
    num_images, num_pixels = flat.shape
    grid_side = int(np.ceil(np.sqrt(num_images)))
    img_side = int(np.ceil(np.sqrt(num_pixels)))

    plt.figure(figsize=(grid_side, grid_side))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)

    for idx in range(num_images):
        ax = plt.subplot(grid[idx])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(flat[idx].reshape([img_side, img_side]))

def show_multimodal(image, slice_index=None):
    """Plot one slice of every channel of a multi-channel 3-D volume on a square grid.

    Parameters
    ----------
    image : array of shape (channels, X, Y, Z), e.g. a (4, 240, 240, 155) stack of MRI modalities.
    slice_index : int, optional
        Index along the last axis to display for every channel. Defaults to the middle
        slice, ``Z // 2``.
    """
    num_channels = image.shape[0]
    if slice_index is None:
        # True middle slice. The previous ceil(Z / 2) was one past the middle and raised
        # IndexError for single-slice volumes (Z == 1).
        slice_index = image.shape[3] // 2

    # Square grid with at least one cell per channel (same rule as show_images). The old
    # ceil(n / 2) cells per side gave a single cell for n == 2 and then ran off the grid.
    figdim = max(1, int(np.ceil(np.sqrt(num_channels))))

    plt.figure(figsize=(figdim, figdim))
    gs = gridspec.GridSpec(figdim, figdim)
    gs.update(wspace=0.05, hspace=0.05)

    for i in range(num_channels):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(image[i, :, :, slice_index])

def preprocess_img(x):
    """Linearly rescale x so that 0 maps to -1 and 1 maps to 1 (inverse of deprocess_img)."""
    return -1.0 + 2 * x

def deprocess_img(x):
    """Linearly rescale x so that -1 maps to 0 and 1 maps to 1 (inverse of preprocess_img)."""
    return 0.5 * (x + 1.0)

def rel_error(x, y):
    """Largest elementwise relative error |x - y| / max(1e-8, |x| + |y|)."""
    abs_diff = np.abs(x - y)
    # The 1e-8 floor keeps the ratio finite where both inputs are zero.
    scale = np.maximum(1e-8, np.abs(x) + np.abs(y))
    return np.max(abs_diff / scale)

def count_params():
    """Count the number of parameters in the current TensorFlow graph.

    Sums the element counts (product of static shape dims) of every variable returned by
    tf.global_variables().
    """
    sizes = [np.prod(var.get_shape().as_list()) for var in tf.global_variables()]
    return np.sum(sizes)


def get_session():
    """Create a tf.Session whose config has gpu_options.allow_growth enabled.

    With allow_growth, TensorFlow grows its GPU memory use as needed instead of claiming
    all of it when the session starts.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)

In [62]:
# Load every MRI modality plus the segmentation for a single BraTS 2017 HGG case.
CASE_DIR = '../HGG/Brats17_2013_2_1'
CASE_PREFIX = os.path.join(CASE_DIR, 'Brats17_2013_2_1')


def load_nifti(path):
    """Return the voxel array stored in the NIfTI file at `path`, in its on-disk dtype.

    `img.get_data()` was deprecated in nibabel 3.0 and removed in 5.0; nibabel documents
    `np.asanyarray(img.dataobj)` as its dtype-preserving replacement.
    """
    return np.asanyarray(nib.load(path).dataobj)


flair_img = load_nifti(CASE_PREFIX + '_flair.nii')
t1_img = load_nifti(CASE_PREFIX + '_t1.nii')
t2_img = load_nifti(CASE_PREFIX + '_t2.nii')
t1ce_img = load_nifti(CASE_PREFIX + '_t1ce.nii')
seg_img = load_nifti(CASE_PREFIX + '_seg.nii')

# To inspect the case: show_multimodal(np.stack([flair_img, t1_img, t2_img, t1ce_img]))

0


In [73]:
# Compress every BraTS case into one `<case>.npz` archive next to its NIfTI files:
#   img: float64 array (4, X, Y, Z), channels ordered (flair, t1, t1ce, t2)
#   seg: segmentation volume in its on-disk dtype
# NOTE: archives written by the previous version of this cell stored the FLAIR volume in
# all four channels and must be regenerated.
MODALITIES = ('flair', 't1', 't1ce', 't2')  # channel order of the saved `img` array


def _find_case_files(root, files):
    """Locate one case's NIfTI files by suffix instead of by (unsorted) listing position.

    Returns (save_prefix, {tag: path}) for tags flair/t1/t1ce/t2/seg, or None unless `files`
    holds exactly one `_<tag>.nii` / `_<tag>.nii.gz` file per tag (e.g. for dataset roots).
    Archives written into `root` (.npz) are ignored, so re-running the cell is safe.
    """
    paths = {}
    for tag in MODALITIES + ('seg',):
        suffixes = ('_{}.nii'.format(tag), '_{}.nii.gz'.format(tag))
        hits = [name for name in files if name.endswith(suffixes)]
        if len(hits) != 1:
            return None
        paths[tag] = os.path.join(root, hits[0])
    flair_name = os.path.basename(paths['flair'])
    save_prefix = os.path.join(root, flair_name[:flair_name.rindex('_flair.nii')])
    return save_prefix, paths


def compress_cases(dataset_dir):
    """Write an (img, seg) .npz archive for every case directory under `dataset_dir`.

    Prints a running case count and returns the number of cases written.
    """
    count = 0
    for root, dirs, files in os.walk(dataset_dir):
        dirs.sort()  # deterministic case order
        case = _find_case_files(root, files)
        if case is None:
            continue
        save_prefix, paths = case
        # `get_data()` is removed in nibabel 5; `dataobj` keeps the on-disk dtype.
        volumes = [np.asanyarray(nib.load(paths[tag]).dataobj) for tag in MODALITIES]
        img = np.stack(volumes).astype(np.float64)
        seg = np.asanyarray(nib.load(paths['seg']).dataobj)
        np.savez_compressed(save_prefix, img=img, seg=seg)
        count += 1
        print(count)
    return count


for dataset_dir in ('../HGG/', '../LGG/'):
    compress_cases(dataset_dir)

        
        

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161


KeyboardInterrupt: 

In [None]:
# Load every LGG case into one float32 array of shape (num_cases, 4, 240, 240, 155) with
# channels (flair, t1, t1ce, t2), re-save each case's .npz archive (img stored as float32),
# and expose the stack to TensorFlow as a non-trainable variable.
# NOTE: 75 cases need roughly 10.7 GB of host memory.
LGG_DIR = '../LGG/'
LGG_MODALITIES = ('flair', 't1', 't1ce', 't2')  # channel order along axis 1
LGG_VOLUME_SHAPE = (240, 240, 155)


def _lgg_case_files(root, files):
    """Return (save_prefix, {tag: path}) for the flair/t1/t1ce/t2/seg NIfTI files in `root`.

    Files are matched by suffix, not by position in the unsorted os.walk listing. Returns
    None unless `files` holds exactly one `_<tag>.nii` / `_<tag>.nii.gz` file per tag.
    """
    paths = {}
    for tag in LGG_MODALITIES + ('seg',):
        suffixes = ('_{}.nii'.format(tag), '_{}.nii.gz'.format(tag))
        hits = [name for name in files if name.endswith(suffixes)]
        if len(hits) != 1:
            return None
        paths[tag] = os.path.join(root, hits[0])
    flair_name = os.path.basename(paths['flair'])
    return os.path.join(root, flair_name[:flair_name.rindex('_flair.nii')]), paths


# Find complete cases first so the output array is allocated exactly once. (The previous
# np.zeros((75, 4, 240, 240, 155)) float64 buffer, ~21 GB, was overwritten unused, and
# only the last case ever reached the TensorFlow variable.)
lgg_cases = []
for root, dirs, files in os.walk(LGG_DIR):
    dirs.sort()  # deterministic case order
    case = _lgg_case_files(root, files)
    if case is not None:
        lgg_cases.append(case)

# float32 halves memory and matches the float32 weights of the TensorFlow layers.
LGGnp = np.zeros((len(lgg_cases), len(LGG_MODALITIES)) + LGG_VOLUME_SHAPE, dtype=np.float32)
for i, (save_prefix, paths) in enumerate(lgg_cases):
    for channel, tag in enumerate(LGG_MODALITIES):
        LGGnp[i, channel] = np.asanyarray(nib.load(paths[tag]).dataobj)
    LGGsegnp = np.asanyarray(nib.load(paths['seg']).dataobj)
    np.savez_compressed(save_prefix, img=LGGnp[i], seg=LGGsegnp)
    print(i + 1)

# Initialising a tf.Variable directly from a NumPy array embeds it in the GraphDef, which is
# capped at 2 GB, so feed it through a placeholder instead. `collections=[]` keeps it out of
# tf.global_variables_initializer(); initialise it explicitly inside a session with
#   sess.run(LGG.initializer, feed_dict={LGG_init: LGGnp})
LGG_init = tf.placeholder(tf.float32, shape=LGGnp.shape, name="LGG_init")
LGG = tf.Variable(LGG_init, name="LGG", trainable=False, collections=[])

In [None]:
# Data feeding function

def BRATSdatafeed():
    """Placeholder for the BraTS training-data feeder; not implemented yet.

    The previous cell held only an unfinished signature (no colon, no body) and did not
    parse. This stub keeps the name defined and fails loudly if it is called.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("BRATSdatafeed is not implemented yet")



In [None]:
# Dice score (similarity; use 1 - DICEscore(...) as a loss to minimise)
def DICEscore(truth, estimate, eps=1e-7):
    """Mean Dice coefficient over a batch.

    Parameters
    ----------
    truth, estimate : tensors of size N x voxels*channels; row i of one is compared with
        row i of the other.
    eps : float
        Smoothing term added to numerator and denominator. A row where both `truth` and
        `estimate` sum to zero then scores 1 instead of NaN (0 / 0). eps=0 reproduces the
        unsmoothed formula.

    Returns
    -------
    Scalar tensor: mean over rows of
        (2 * sum(truth * estimate) + eps) / (sum(truth) + sum(estimate) + eps).
    Higher is better.
    """
    intersection = tf.reduce_sum(tf.multiply(truth, estimate), axis=1)
    total = tf.reduce_sum(truth, axis=1) + tf.reduce_sum(estimate, axis=1)
    return tf.reduce_mean(tf.divide(2.0 * intersection + eps, total + eps))

In [114]:
def brainconvnet(x):
    """Build the 3-D convolutional encoder/decoder graph.

    Inputs:
    - x: TensorFlow Tensor of flattened input volumes, shape [batch_size, 4 * 240 * 240 * 155].
      IMPORTANT: each row must be a channels-first (4, 240, 240, 155) volume flattened in
      C order. The batch dimension may be unknown (e.g. a [None, ...] placeholder).

    Returns:
    TensorFlow Tensor with shape [batch_size, 5 * 240 * 240 * 155]: unnormalised per-voxel
    scores for 5 classes, flattened from a channels-first layout (axis 0 = image,
    axis 1 = class, remaining axes = voxels).

    Spatial sizes (X, Y, Z): 240,240,155 -> pool 120,120,78 -> pool 60,60,39 -> pool 30,30,20;
    each stride-2 transposed convolution restores the size of the matching encoder stage.
    """
    with tf.variable_scope("convNet"):
        # Undo the channels-first flattening, then move channels last as tf.layers.conv3d expects.
        x_reshape = tf.transpose(tf.reshape(x, [-1, 4, 240, 240, 155]), [0, 2, 3, 4, 1])
        # Dynamic batch size: the static one is None for placeholders.
        batch_size = tf.shape(x_reshape)[0]

        # 'same' pooling rounds odd sizes up (155 -> 78 -> 39 -> 20), so each stride-2 SAME
        # transposed convolution below maps back to exactly the matching encoder size.
        conv1 = tf.layers.conv3d(inputs=x_reshape, filters=8,
                                 kernel_size=[11, 11, 11], padding="same", activation=tf.nn.relu)
        pool1 = tf.layers.max_pooling3d(inputs=conv1, pool_size=(2, 2, 2),
                                        strides=(2, 2, 2), padding='same')

        conv2 = tf.layers.conv3d(inputs=pool1, filters=8,
                                 kernel_size=[5, 5, 5], padding="same", activation=tf.nn.relu)
        pool2 = tf.layers.max_pooling3d(inputs=conv2, pool_size=(2, 2, 2),
                                        strides=(2, 2, 2), padding='same')

        conv3 = tf.layers.conv3d(inputs=pool2, filters=32,
                                 kernel_size=[3, 3, 3], padding="same", activation=tf.nn.relu)
        pool3 = tf.layers.max_pooling3d(inputs=conv3, pool_size=(2, 2, 2),
                                        strides=(2, 2, 2), padding='same')

        conv4 = tf.layers.conv3d(inputs=pool3, filters=128,
                                 kernel_size=[3, 3, 3], padding="same", activation=tf.nn.relu)

    with tf.variable_scope("deconvNet"):
        # Stride-1 transposed convolutions cannot change spatial size, so upsample by 2 and
        # take each target shape from the encoder activation of the same resolution
        # (conv3 has 32 channels and conv2 has 8, matching the filters below).
        W3 = tf.Variable(tf.truncated_normal([3, 3, 3, 32, 128], stddev=0.1))
        deconv3 = tf.nn.conv3d_transpose(conv4, filter=W3, output_shape=tf.shape(conv3),
                                         strides=[1, 2, 2, 2, 1], padding='SAME')
        b3 = tf.Variable(tf.constant(0.1, shape=[32]))
        relu3 = tf.nn.relu(deconv3 + b3)

        W2 = tf.Variable(tf.truncated_normal([3, 3, 3, 8, 32], stddev=0.1))
        deconv2 = tf.nn.conv3d_transpose(relu3, filter=W2, output_shape=tf.shape(conv2),
                                         strides=[1, 2, 2, 2, 1], padding='SAME')
        b2 = tf.Variable(tf.constant(0.1, shape=[8]))
        relu2 = tf.nn.relu(deconv2 + b2)

        W1 = tf.Variable(tf.truncated_normal([3, 3, 3, 5, 8], stddev=0.1))
        deconv1 = tf.nn.conv3d_transpose(relu2, filter=W1,
                                         output_shape=tf.stack([batch_size, 240, 240, 155, 5]),
                                         strides=[1, 2, 2, 2, 1], padding='SAME')
        b1 = tf.Variable(tf.constant(0.1, shape=[5]))
        # Back to channels-first (image, class, x, y, z) before flattening, as documented.
        logits = tf.transpose(deconv1 + b1, [0, 4, 1, 2, 3])

    return tf.reshape(logits, [-1, 5 * 240 * 240 * 155])

In [115]:
# Get number of network parameters
def paramcount():
    """Rebuild brainconvnet in a fresh default graph and return its parameter count.

    Side effect: resets the default TensorFlow graph. Counting only inspects the graph's
    global variables, so no session is opened. The previous version created one via
    get_session(), which can claim GPU resources, and never used it.
    """
    tf.reset_default_graph()
    brainconvnet(tf.ones((2, 240 * 240 * 155 * 4)))
    return count_params()

print(paramcount())

286901


In [None]:
# NOTE(review): this cell appears to be left over from the CS 231N GAN assignment.
# sample_noise, generator, discriminator, get_solvers and gan_loss are not defined in the
# visible part of this notebook, so on a fresh kernel this cell would raise NameError.
# TODO confirm, then either delete the cell or port those definitions.
tf.reset_default_graph()

batch_size = 128
# our noise dimension
noise_dim = 96

# placeholders for images from the training dataset
# NOTE(review): 784 = 28 * 28 flattened 2-D images, not BraTS volumes; presumably MNIST. TODO confirm.
x = tf.placeholder(tf.float32, [None, 784])
z = sample_noise(batch_size, noise_dim)
# generated images
G_sample = generator(z)

with tf.variable_scope("") as scope:
    #scale images to be -1 to 1 (preprocess_img maps 0 -> -1 and 1 -> 1)
    logits_real = discriminator(preprocess_img(x))
    # Re-use discriminator weights on new inputs
    scope.reuse_variables()
    logits_fake = discriminator(G_sample)

# Get the list of variables for the discriminator and generator
# (filtered by scope name: only variables created under 'discriminator' / 'generator')
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'generator') 

D_solver,G_solver = get_solvers()
D_loss, G_loss = gan_loss(logits_real, logits_fake)
# Each optimizer only updates its own network's variables.
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)
# Extra update ops (e.g. from batch norm, if the networks use it) collected per scope.
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS,'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS,'generator')