In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os


import tensorflow as tf
import nibabel as nib

from tqdm import tqdm
from tfrecord_utils import *

tf.enable_eager_execution()

  from ._conv import register_converters as _register_converters


In [2]:
# Location of the TFRecord dataset. Set the TF_RECORD_FILENAME environment
# variable to point elsewhere instead of editing this machine-specific default.
TF_RECORD_FILENAME = os.environ.get(
    "TF_RECORD_FILENAME", "/home-local/remedis/dataset.tfrecords")

In [3]:
# Stream serialized records from the TFRecord file and decode each one with
# parse_bag (presumably provided by the tfrecord_utils star import). The
# training loop below unpacks each decoded element as an (x, y) pair.
raw_records = tf.data.TFRecordDataset(TF_RECORD_FILENAME)
dataset = raw_records.map(lambda serialized: parse_bag(serialized, (64, 64), 1))
# One-shot iterator, used further down to pull a single bag for a smoke test.
iterator = dataset.make_one_shot_iterator()

Instructions for updating:
Colocations handled automatically by placer.


In [4]:
# Model: three 32-filter 3x3 ReLU conv layers (the first two each followed by
# 2x2 max pooling), a global max pool over whatever spatial extent remains,
# and a single sigmoid output unit. Height/width are None, so any input size
# with one channel is accepted.
inputs = tf.keras.layers.Input(shape=(None, None, 1))
x = inputs
for block_idx in range(3):
    x = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')(x)
    # Only the first two conv layers are followed by pooling.
    if block_idx < 2:
        x = tf.keras.layers.MaxPooling2D(2, 2)(x)
x = tf.keras.layers.GlobalMaxPooling2D()(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

In [5]:
# Layer-by-layer overview: output shapes and parameter counts, printed to stdout.
model.summary(print_fn=print)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, None, None, 1)     0         
_________________________________________________________________
conv2d (Conv2D)              (None, None, None, 32)    320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, None, None, 32)    0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, None, None, 32)    9248      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, None, None, 32)    0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, None, None, 32)    9248      
_________________________________________________________________
global_max_pooling2d (Global (None, 32)                0         
__________

In [None]:
def loss_fn(model, x, y):
    """Unreduced binary cross-entropy of ``model(x, training=True)`` against ``y``.

    The model's last layer is ``Dense(1, activation='sigmoid')``, so its outputs
    are already probabilities. ``tf.losses.sigmoid_cross_entropy`` expects
    logits and would apply a second sigmoid; ``tf.losses.log_loss`` consumes
    probabilities directly.

    Args:
        model: Keras model whose outputs are probabilities in [0, 1].
        x: input batch fed to ``model``.
        y: labels with the same number of elements as the model output
            (e.g. a single bag label when ``x`` holds one instance).

    Returns:
        Loss tensor with the same shape as the model output (no reduction).
    """
    predictions = model(x, training=True)
    # log_loss requires label/prediction shapes to be compatible. A (1,) label
    # against a (1, 1) prediction would raise, so align the label shape first.
    labels = tf.reshape(tf.cast(y, predictions.dtype), tf.shape(predictions))
    return tf.losses.log_loss(labels, predictions, reduction=tf.losses.Reduction.NONE)

def grad(model, x, y):
    """Return the bag loss and its gradients w.r.t. ``model.trainable_variables``.

    Each instance of the bag ``x`` (iterated along its first axis) is scored on
    its own, with a leading batch dimension of 1, and paired with the bag label
    ``y``. The per-instance losses are averaged into a scalar ``loss_value``.

    Args:
        model: Keras model passed through to ``loss_fn``.
        x: bag tensor whose first axis indexes instances.
        y: bag label shared by every instance.

    Returns:
        ``(loss_value, gradients)``, where ``gradients`` is aligned with
        ``model.trainable_variables``.

    Raises:
        ValueError: if ``x`` contains no instances.
    """
    # Equivalent to reshaping to (1,) + y.shape. Hoisted because the label is
    # shared by every instance.
    y_batch = tf.expand_dims(y, 0)
    instance_losses = []
    with tf.GradientTape() as tape:
        for x_instance in x:
            instance_losses.append(
                loss_fn(model, tf.expand_dims(x_instance, 0), y_batch))
        if not instance_losses:
            raise ValueError("grad() received a bag with no instances")
        # Previously loss_value was never defined (NameError). The reduction
        # must run inside the tape so gradients flow through it.
        # NOTE(review): averaging over instances is an assumed aggregation;
        # MIL setups often pool instance predictions (e.g. max) instead.
        loss_value = tf.reduce_mean(tf.stack(instance_losses))
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

In [None]:
# training configuration
N_EPOCHS = 5
# NOTE(review): 61 is presumably the number of bags in the TFRecord file;
# confirm, since the progress bar's ETA depends on it.
progbar = tf.keras.utils.Progbar(target=61)

epoch_loss_avg = tf.contrib.eager.metrics.Mean()
epoch_accuracy = tf.contrib.eager.metrics.Accuracy()

# The training loop calls opt.apply_gradients, but no optimizer was ever
# created. NOTE(review): Adam with its default learning rate is an assumed
# choice; tune as needed.
opt = tf.train.AdamOptimizer()

In [None]:
# Fetch a single element from the one-shot iterator for a quick smoke test.
next_element = next(iterator)

In [None]:
# Smoke test: loss and gradients for the single element fetched above. At this
# point ``x`` was the symbolic Keras tensor from the model cell and ``y`` was
# undefined, so unpack the fetched (x, y) element instead.
x_sample, y_sample = next_element
loss_value, grads = grad(model, x_sample, y_sample)

In [None]:


# Training loop: one optimizer step per bag, reporting per-epoch loss/accuracy.
for cur_epoch in range(N_EPOCHS):
    print("\nEpoch {}/{}".format(cur_epoch + 1, N_EPOCHS))
    # Fresh progress bar and metrics every epoch. Reusing the ones created once
    # in the config cell made the printed stats cumulative across epochs, and
    # the bar was never advanced.
    progbar = tf.keras.utils.Progbar(target=progbar.target)
    epoch_loss_avg = tf.contrib.eager.metrics.Mean()
    epoch_accuracy = tf.contrib.eager.metrics.Accuracy()
    progbar.update(0)
    for step, (x, y) in enumerate(dataset):
        loss_value, grads = grad(model, x, y)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        epoch_loss_avg(loss_value)
        # The model ends in a single sigmoid unit, so argmax over axis 1 was
        # always 0. Threshold each instance's probability at 0.5 and compare it
        # with the bag label, matching how grad() pairs every instance with y.
        # NOTE(review): assumes y holds a single 0/1 label per bag.
        predictions = tf.cast(model(x) > 0.5, tf.int32)
        labels = tf.fill(tf.shape(predictions), tf.cast(tf.reshape(y, []), tf.int32))
        epoch_accuracy(labels, predictions)

        progbar.update(step + 1)

    # 1-based epoch number, consistent with the header printed above.
    print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(cur_epoch + 1,
                                                                epoch_loss_avg.result().numpy(),
                                                                epoch_accuracy.result().numpy()))

In [None]:
# read and print: show the shape of one bag and display one 2-D slice of it.
PREVIEW_INSTANCE = 200  # preferred instance index; clamped to the bag size

for bag, label in dataset.take(1):
    print(bag.shape)
    # Indexing instance 200 unconditionally raised for bags with <= 200
    # instances, so clamp to the last available one.
    instance_idx = min(PREVIEW_INSTANCE, int(bag.shape[0]) - 1)
    fig, ax = plt.subplots()
    # imshow colour-normalises single-channel data itself. Casting to int32
    # first would only truncate fractional intensities (e.g. blank out data
    # normalised to [0, 1]).
    ax.imshow(bag[instance_idx, :, :, 0].numpy())
    ax.set_title(str(label.numpy()))