In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
from glob import glob
import math
import matplotlib.pyplot as plt
import nibabel as nb
from nilearn import image
import numpy as np
import os
import pandas as pd
import random
import tensorflow as tf
import timeit
import warnings

  from ._conv import register_converters as _register_converters


In [2]:
def lrelu(x, a=0.1):
    return tf.maximum(a*x, x)

def G_conv(batch_input, out_channels):
    return tf.layers.conv3d(batch_input, out_channels, kernel_size=4, strides=(2, 2, 2), padding="valid")

def D_conv(batch_input, out_channels):
    return tf.layers.conv3d(batch_input, out_channels, kernel_size=4, strides=(1, 1, 1), padding="same")

def D_max_pool(batch_input):
    return tf.layers.max_pooling3d(batch_input, 2, 2)

def G_conv_transpose(batch_input, out_channels):
    return tf.layers.conv3d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2, 2), padding="valid")

def batchnorm(batch_input):
    return tf.layers.batch_normalization(batch_input, epsilon=1e-5, momentum=0.1, training=True)

def generator(G_in, G_out_channels):
    with tf.variable_scope("generator"):
        layers = []
        G_in = tf.expand_dims(G_in, -1)
        # encoder_1
        with tf.variable_scope("encoder_1"):
            conv_out = G_conv(G_in, ngf)

    #        output = gen_conv(G_in, ngf)
            layers.append(conv_out)

        layer_nfilters = [
            ngf, # encoder_2
            ngf * 2, # encoder_3
            ngf * 4, # encoder_4
        ]

        for out_n in layer_nfilters:
            with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
                rectified = lrelu(layers[-1], 0.1)
                # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                convolved = G_conv(G_in, out_n)
                output = batchnorm(convolved)
                layers.append(output)

        layer_specs = [
#            ngf * 4,   # decoder_4
#            ngf * 2,   # decoder_3
            ngf,       # decoder_2
        ]

        num_encoder_layers = len(layers)
        for decoder_layer, out_channels in enumerate(layer_specs):
            skip_layer = num_encoder_layers - decoder_layer - 1
            with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
    #             if decoder_layer == 0:
    #                 # first decoder layer doesn't have skip connections
    #                 # since it is directly connected to the skip_layer
    #                 decoder_input = layers[-1]
    #             else:
    #                 decoder_input = tf.concat([layers[-1], layers[skip_layer]], axis=4)
                decoder_input = layers[-1]
                rectified = tf.nn.relu(decoder_input)

                decoder_output = G_conv_transpose(rectified, out_channels)

                output = batchnorm(decoder_output)
                
                layers.append(output)

        # decoder_1
#         with tf.variable_scope("decoder_1"):
#             #dec_1_input = tf.concat([layers[-1], layers[0]], axis=4)
#             dec_1_input = layers[-1]
#             rectified = tf.nn.relu(dec_1_input)
#             output = G_conv_transpose(rectified, G_out_channels)
#             output = tf.tanh(output)
#             layers.append(output)

        return layers[-1]

def site_discriminator(D_input):
    with tf.variable_scope("site_discriminator"):
        n_layers = 2
        layers = []

        #D_input = tf.concat([discrim_inputs, discrim_targets], axis=4)

        # layer_1:
#        D_input = tf.expand_dims(D_input, -1)
        with tf.variable_scope("layer_1"):
            convolved = D_conv(D_input, ndf)
            pooled = D_max_pool(convolved)
            rectified = lrelu(pooled, 0.1)
            layers.append(rectified)

        # layer_2:
        # layer_3:
        # layer_4:
        for i in range(n_layers):
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                out_channels = ndf/(2*(i+1))

                #stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                convolved = D_conv(layers[-1], out_channels)
                pooled = D_max_pool(convolved)
                normalized = batchnorm(pooled)
                rectified = lrelu(normalized, 0.1)
                layers.append(rectified)

        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            convolved = D_conv(rectified, out_channels=1)
            fc1 = tf.contrib.layers.flatten(convolved)
#            fc1 = tf.layers.dense(fc1, 300)
#            fc1 = lrelu(fc1)
            # Output layer, class prediction
            out = tf.layers.dense(fc1, 17)
            layers.append(out)
    #         pred_classes = tf.argmax(out, axis=1)
    #         layers.append(pred_classes)
        return layers[-1]

def qc_discriminator(D_input):
    with tf.variable_scope("qc_discriminator"):
        n_layers = 2
        layers = []

    #    D_input = tf.concat([discrim_inputs, discrim_targets], axis=4)
    #    D_input = tf.expand_dims(D_input, -1)

        # layer_1:
        with tf.variable_scope("layer_1"):
            convolved = D_conv(D_input, ndf)
            pooled = D_max_pool(convolved)
            rectified = lrelu(pooled, 0.1)
            layers.append(rectified)

        # layer_2:
        # layer_3:
        # layer_4:
        for i in range(n_layers):
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                out_channels = ndf/(2*(i+1))
                #stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                convolved = D_conv(layers[-1], out_channels)
                pooled = D_max_pool(convolved)
                normalized = batchnorm(pooled)
                rectified = lrelu(normalized, 0.1)
                layers.append(rectified)

        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            convolved = D_conv(rectified, out_channels=1)

            fc1 = tf.contrib.layers.flatten(convolved)
#            fc1 = tf.layers.dense(fc1, 300)
#            fc1 = lrelu(fc1)
            # Output layer, class prediction
            out = tf.layers.dense(fc1, 2)
            layers.append(out)
    #         pred_classes = tf.argmax(out, axis=1)
    #         layers.append(pred_classes)
        return layers[-1]

In [3]:
EPS = 1e-12
ngf = 8 #number of generator filters in first conv layer
ndf = 16 #number of discriminator filters in first conv layer
seed = 123
lr = 0.0001 #initial learning rate for adam
beta1 = 0.5 #momentum term of adam"

features = tf.placeholder(np.float32, [4, 106, 128, 110])
qc_labels = tf.placeholder(np.int32, [4])
site_labels = tf.placeholder(np.int32, [4])

with tf.variable_scope("generator"):
    debiased_channels = int(qc_labels.get_shape()[-1])
    debiased = generator(features, debiased_channels)

with tf.variable_scope("qc_discriminator"):
    qc_out = qc_discriminator(debiased)
    x_in = tf.identity(qc_out)

with tf.variable_scope("site_discriminator"):
    site_out = site_discriminator(debiased)
    
qc_D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'qc_discriminator')
site_D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'site_discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator') 

site_D_solver = tf.train.AdamOptimizer(lr, beta1)
qc_D_solver = tf.train.AdamOptimizer(lr, beta1)
G_solver = tf.train.AdamOptimizer(lr, beta1)

with tf.name_scope("G_loss"):
    G_loss = tf.abs(tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=site_out, labels=site_labels)) - tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=qc_out, labels=qc_labels)))

with tf.name_scope("QC_D_loss"):
    qc_D_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=qc_out, labels=qc_labels))

with tf.name_scope("Site_D_loss"):
    site_D_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=site_out, labels=site_labels))
    
site_D_train_step = site_D_solver.minimize(site_D_loss, var_list=site_D_vars)
qc_D_train_step = qc_D_solver.minimize(qc_D_loss, var_list=qc_D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)
qc_D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'qc_discriminator')
site_D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'site_discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')

In [4]:
from source_code.data_io import Dataset_Pipeline, _get_data

def run_a_gan(sess, G_train_step, G_loss,\
              qc_D_train_step, qc_D_loss,\
              site_D_train_step, site_D_loss,\
              G_extra_step, qc_D_extra_step, site_D_extra_step,\
              num_epoch=10):
    """Train a GAN for a certain number of epochs.
    
    Inputs:
    - sess: A tf.Session that we want to use to run our data
    - G_train_step: A training step for the Generator
    - G_loss: Generator loss
    - D_train_step: A training step for the Generator
    - D_loss: Discriminator loss
    - G_extra_step: A collection of tf.GraphKeys.UPDATE_OPS for generator
    - D_extra_step: A collection of tf.GraphKeys.UPDATE_OPS for discriminator
    Returns:
        Nothing
    """
    log_dir = "logs"
    current_run_subdir = os.path.join(
        "run_" + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    model_dir = os.path.join(log_dir, 'dbGAN', '6012018')

    ds = Dataset_Pipeline(target_shape=(106, 128, 110),
                          n_epochs=10,
                          train_src_folder="/home/smantra/finalproject/data/",
                          train_cache_prefix="/home/smantra/finalproject/cache_train/",
                          eval_src_folder="/home/smantra/finalproject/eval/",
                          eval_cache_prefix="/home/smantra/finalproject/cache_eval/",
                          batch_size=4
                         )
    train_dataset = _get_data(batch_size=ds.batch_size,
                                  src_folder=ds.train_src_folder,
                                  n_epochs=10,
                                  cache_prefix=ds.train_cache_prefix,
                                  shuffle=False,
                                  target_shape=ds.target_shape,
                                 )

    ds_it = train_dataset.make_one_shot_iterator()
    next_batch = ds_it.get_next()
    
    
    print("Starting training!")
    
    for epoch in range(num_epoch):
        for step in range(200):
            #run a batch of data through the network
            feats, (qc_labs, site_labs) = sess.run(next_batch, options=tf.RunOptions(
        report_tensor_allocations_upon_oom=True))
            num_debiased_channels = int(qc_labels.get_shape()[-1])
            feed_dict={features: feats,
                       qc_labels : qc_labs,
                       site_labels : site_labs,                       
                      }
            devices = ['/gpu:0', '/gpu:1']
            for d in devices:
                with tf.device(d):
                    _, G_loss_curr = sess.run([G_train_step, G_loss], feed_dict=feed_dict, options=tf.RunOptions(
        report_tensor_allocations_upon_oom=True))
                    _, qc_D_loss_curr = sess.run([qc_D_train_step, qc_D_loss], feed_dict=feed_dict, options=tf.RunOptions(
        report_tensor_allocations_upon_oom=True))
                    _, site_D_loss_curr = sess.run([site_D_train_step, site_D_loss], feed_dict=feed_dict, options=tf.RunOptions(
        report_tensor_allocations_upon_oom=True))
            if (step % 10 == 0):
                print('Step: {}, qc_D: {:.4}, site_D: {:.4}, G:{:.4}'.format(step,qc_D_loss_curr,site_D_loss_curr,G_loss_curr))
        # Print loss every epoch
        print('Epoch: {}, qc_D: {:.4}, site_D: {:.4}, G:{:.4}'.format(epoch,qc_D_loss_curr,site_D_loss_curr,G_loss_curr))

In [5]:
def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    return session
devices = ['/gpu:0', '/gpu:1']
d = devices[1]

with get_session() as sess:
    sess.run(tf.global_variables_initializer(), options=tf.RunOptions(
        report_tensor_allocations_upon_oom=True))
    run_a_gan(sess, G_train_step, G_loss,\
              qc_D_train_step, qc_D_loss,\
              site_D_train_step, site_D_loss,\
              G_extra_step, qc_D_extra_step, site_D_extra_step,\
              num_epoch=10)


<TensorSliceDataset shapes: ((),), types: (tf.string,)>
Starting training!


ResourceExhaustedError: OOM when allocating tensor with shape[4,16,214,258,222] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[Node: site_discriminator/site_discriminator/layer_1/conv3d/Conv3D = Conv3D[T=DT_FLOAT, data_format="NDHWC", dilations=[1, 1, 1, 1, 1], padding="SAME", strides=[1, 1, 1, 1, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](generator/generator/decoder_1/Tanh, site_discriminator/site_discriminator/layer_1/conv3d/kernel/read)]]

Current usage from device: /job:localhost/replica:0/task:0/device:GPU:0, allocator: GPU_0_bfc
  748.11MiB from generator/generator/decoder_1/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from generator/generator/decoder_4/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from gradients_2/generator/generator/decoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  182.19MiB from generator/generator/decoder_4/batch_normalization/batchnorm/mul_1
  92.40MiB from generator/generator/encoder_4/conv3d/Conv3D
  86.38MiB from gradients_2/generator/generator/encoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  86.38MiB from generator/generator/encoder_4/batch_normalization/batchnorm/mul_1
  Remaining 6 nodes with 1.5KiB

	 [[Node: G_loss/Abs/_7 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1009_G_loss/Abs", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Current usage from device: /job:localhost/replica:0/task:0/device:GPU:0, allocator: GPU_0_bfc
  748.11MiB from generator/generator/decoder_1/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from generator/generator/decoder_4/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from gradients_2/generator/generator/decoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  182.19MiB from generator/generator/decoder_4/batch_normalization/batchnorm/mul_1
  92.40MiB from generator/generator/encoder_4/conv3d/Conv3D
  86.38MiB from gradients_2/generator/generator/encoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  86.38MiB from generator/generator/encoder_4/batch_normalization/batchnorm/mul_1
  Remaining 6 nodes with 1.5KiB


Caused by op 'site_discriminator/site_discriminator/layer_1/conv3d/Conv3D', defined at:
  File "/home/shared/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/shared/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/shared/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/shared/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 478, in start
    self.io_loop.start()
  File "/home/shared/anaconda3/lib/python3.6/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/home/shared/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/shared/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-354dea7b2ffa>", line 21, in <module>
    site_out = site_discriminator(debiased)
  File "<ipython-input-2-ee445b97ac89>", line 90, in site_discriminator
    convolved = D_conv(D_input, ndf)
  File "<ipython-input-2-ee445b97ac89>", line 8, in D_conv
    return tf.layers.conv3d(batch_input, out_channels, kernel_size=4, strides=(1, 1, 1), padding="same")
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/convolutional.py", line 828, in conv3d
    return layer.apply(inputs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 828, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 717, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/layers/convolutional.py", line 168, in call
    outputs = self._convolution_op(inputs, self.kernel)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/nn_ops.py", line 868, in __call__
    return self.conv_op(inp, filter)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/nn_ops.py", line 520, in __call__
    return self.call(inp, filter)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/nn_ops.py", line 204, in __call__
    name=self.name)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 1361, in conv3d
    name=name)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[4,16,214,258,222] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[Node: site_discriminator/site_discriminator/layer_1/conv3d/Conv3D = Conv3D[T=DT_FLOAT, data_format="NDHWC", dilations=[1, 1, 1, 1, 1], padding="SAME", strides=[1, 1, 1, 1, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](generator/generator/decoder_1/Tanh, site_discriminator/site_discriminator/layer_1/conv3d/kernel/read)]]

Current usage from device: /job:localhost/replica:0/task:0/device:GPU:0, allocator: GPU_0_bfc
  748.11MiB from generator/generator/decoder_1/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from generator/generator/decoder_4/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from gradients_2/generator/generator/decoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  182.19MiB from generator/generator/decoder_4/batch_normalization/batchnorm/mul_1
  92.40MiB from generator/generator/encoder_4/conv3d/Conv3D
  86.38MiB from gradients_2/generator/generator/encoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  86.38MiB from generator/generator/encoder_4/batch_normalization/batchnorm/mul_1
  Remaining 6 nodes with 1.5KiB

	 [[Node: G_loss/Abs/_7 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1009_G_loss/Abs", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Current usage from device: /job:localhost/replica:0/task:0/device:GPU:0, allocator: GPU_0_bfc
  748.11MiB from generator/generator/decoder_1/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from generator/generator/decoder_4/conv3d_transpose/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer
  182.19MiB from gradients_2/generator/generator/decoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  182.19MiB from generator/generator/decoder_4/batch_normalization/batchnorm/mul_1
  92.40MiB from generator/generator/encoder_4/conv3d/Conv3D
  86.38MiB from gradients_2/generator/generator/encoder_4/batch_normalization/moments/SquaredDifference_grad/sub
  86.38MiB from generator/generator/encoder_4/batch_normalization/batchnorm/mul_1
  Remaining 6 nodes with 1.5KiB

