In [41]:
# Third-party imports.
import numpy as np
import matplotlib.pyplot as plt

# TensorFlow imports.
import tensorflow as tf
from tensorflow import keras

# Colab/Jupyter imports.
from IPython.display import clear_output
import ipywidgets as widgets
from IPython.display import display

# Loading Data

In [None]:
def parse_example(example_proto):
    """Decode one serialized tf.train.Example into an (x, y) tensor pair.

    The record must contain two float lists, "X_data" and "Y_data". They are
    densified and reshaped to (35, 35, 35, 1) and (9, 9, 9, 1) respectively.
    The layout is (height, width, length, channels).
    """
    schema = {
        key: tf.io.VarLenFeature(tf.float32) for key in ("X_data", "Y_data")
    }
    parsed = tf.io.parse_single_example(example_proto, schema)

    # Fixed volume shapes: (height, width, length, channels).
    input_volume_shape = (35, 35, 35, 1)
    target_volume_shape = (9, 9, 9, 1)

    # VarLenFeature yields SparseTensors, so densify before reshaping.
    features = tf.reshape(tf.sparse.to_dense(parsed["X_data"]), input_volume_shape)
    labels = tf.reshape(tf.sparse.to_dense(parsed["Y_data"]), target_volume_shape)

    return features, labels

def load_tfrecord(filename):
    """Build an element-wise tf.data pipeline over a TFRecord file.

    Args:
        filename: Path to the TFRecord file. It is passed straight to
            `tf.data.TFRecordDataset`, which also accepts a list of paths.

    Returns:
        A `tf.data.Dataset` yielding unbatched (x, y) pairs as produced by
        `parse_example`. Shuffling and batching are left to the caller.
    """
    dataset = tf.data.TFRecordDataset(filename)
    # Parse records in parallel. tf.data keeps output order deterministic by
    # default, so results are identical to a serial map.
    dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    # Overlap record preparation with downstream consumption.
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

# Element-wise (unbatched) pipelines, one per split. Each split variable
# loads the file whose name matches it.
train_dataset = load_tfrecord("training_dataset.tfrecord")
val_dataset = load_tfrecord("validation_dataset.tfrecord")
test_dataset = load_tfrecord("testing_dataset.tfrecord")

# Model Architecture & Helper Functions

In [30]:
class Net(keras.Model):
    """Fully convolutional 3-D network that maps a 35^3 volume to a 9^3 volume.

    Spatial sizes through the network: zero-pad 35 -> 36, max-pool 36 -> 18,
    max-pool 18 -> 9. Every Conv3D uses padding="same", so convolutions keep
    the spatial size. Inputs shaped (batch, 35, 35, 35, 1) produce outputs
    shaped (batch, 9, 9, 9, 1).
    """

    def __init__(self, input_shape=(None, 35, 35, 35, 1), **kwargs):
        """Create the layers.

        Args:
            input_shape: Accepted for backward compatibility but unused. The
                layers build their weights lazily on the first call.
            **kwargs: Forwarded to `keras.Model` (e.g. `name`).
        """
        super().__init__(**kwargs)

        # Pad one voxel at the end of each padded axis (35 -> 36) so the two
        # 2x max-pools divide evenly (36 -> 18 -> 9).
        self.pad = keras.layers.ZeroPadding3D(padding=((0, 1), (0, 1), (0, 1)))
        self.conv1 = keras.layers.Conv3D(32, (3, 3, 3), padding="same", activation="relu")
        self.pool1 = keras.layers.MaxPooling3D((2, 2, 2))
        self.conv2 = keras.layers.Conv3D(64, (3, 3, 3), padding="same", activation="relu")
        self.pool2 = keras.layers.MaxPooling3D((2, 2, 2))
        self.conv3 = keras.layers.Conv3D(128, (3, 3, 3), padding="same", activation="relu")
        # Single-channel output projection. ReLU constrains predictions to >= 0.
        self.conv4 = keras.layers.Conv3D(1, (3, 3, 3), padding="same", activation="relu")

    def call(self, inputs):
        """Run the forward pass: pad, two conv+pool stages, then two convs."""
        x = self.pad(inputs)
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

def compute_validation_loss(validation_data, batch_size):
    """Return the mean absolute error of the global `model` over a dataset.

    Called at the end of every training epoch.

    Args:
        validation_data: Unbatched `tf.data.Dataset` of (X, Y) pairs. This
            function batches it itself.
        batch_size: Number of samples per evaluation batch.

    Returns:
        The sample-weighted mean of the per-batch MAE, as a Python float.

    Raises:
        ValueError: If `validation_data` yields no elements.

    NOTE(review): this reads the module-level `model` rather than taking it
    as an argument.
    """
    # Evaluation does not depend on element order, so no shuffle is needed.
    batches = validation_data.batch(batch_size)

    weighted_loss_sum = 0.0
    n_samples = 0
    for X_batch, Y_batch in batches:
        Y_pred_val = model(X_batch, training=False)
        batch_mae = tf.reduce_mean(keras.losses.MAE(Y_batch, Y_pred_val))

        # Weight each batch by its size so a short final batch does not skew
        # the average.
        batch_n = int(tf.shape(X_batch)[0])
        weighted_loss_sum += float(batch_mae) * batch_n
        n_samples += batch_n

    if n_samples == 0:
        raise ValueError("validation_data is empty; cannot compute a validation loss")

    return weighted_loss_sum / n_samples

def model_predict_sample(model, testing_data, i):
    """Run `model` on the i-th (X, Y) element of an unbatched dataset.

    Args:
        model: Keras model called as `model(x, training=False)` on a batch of
            one element.
        testing_data: Unbatched `tf.data.Dataset` of (X, Y) pairs whose
            last axis has size 1.
        i: Zero-based index of the element to use.

    Returns:
        Tuple `(data, pred, target)` of NumPy arrays with the size-1 channel
        axis removed (and the batch axis removed from `pred`). For this
        notebook's data that is (35, 35, 35), (9, 9, 9) and (9, 9, 9).

    Raises:
        IndexError: If `testing_data` has no element at index `i`.
    """
    sample = list(testing_data.skip(i).take(1))
    if not sample:
        raise IndexError(f"testing_data has no element at index {i}")
    X, Y = sample[0]

    data = np.squeeze(X.numpy(), axis=-1)
    target = np.squeeze(Y.numpy(), axis=-1)

    # Forward pass on a batch of one, then drop the batch and channel axes.
    pred = model(tf.expand_dims(X, axis=0), training=False)
    pred = np.squeeze(pred.numpy(), axis=(0, -1))

    return data, pred, target

# View Model

In [31]:
# Create the model.
model = Net()

# A subclassed Keras model has no weights until it is first called, so run one
# dummy forward pass to build it. Keras layers default to channels-last, hence
# the shape (batch, 35, 35, 35, channels).
model(tf.random.normal((1, 35, 35, 35, 1)))

# Per-layer summary (requires the model to be built, as above).
model.summary()

# Train Model

In [48]:
"""HYPERPARAMERS"""
EPOCHS = 10
LEARNING_RATE = 0.001
BATCH_SIZE = 256

# Batch under a new name so that re-running this cell does not batch
# `train_dataset` a second time (which would nest batch dimensions).
train_batches = train_dataset.shuffle(buffer_size=10000).batch(BATCH_SIZE)

optimiser = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# Per-epoch history for the progress plot.
train_losses = []
val_losses = []
epochs = []

for epoch in range(1, EPOCHS + 1):

    # --- Training (batch level) ---
    epoch_loss_sum = 0.0
    n_batches = 0
    for X_batch, Y_batch in train_batches:
        with tf.GradientTape() as tape:
            Y_pred = model(X_batch, training=True)
            loss = tf.reduce_mean(keras.losses.MAE(Y_batch, Y_pred))

        gradients = tape.gradient(loss, model.trainable_variables)
        optimiser.apply_gradients(zip(gradients, model.trainable_variables))

        epoch_loss_sum += float(loss)
        n_batches += 1

    # Report the mean over the epoch rather than only the last batch's loss.
    train_loss = epoch_loss_sum / n_batches if n_batches else float("nan")
    train_losses.append(train_loss)

    val_loss = compute_validation_loss(val_dataset, BATCH_SIZE)
    val_losses.append(val_loss)

    epochs.append(epoch)

    # --- Graphing ---
    # Replace the previous epoch's figure and text.
    clear_output(wait=True)

    data, pred, target = model_predict_sample(model, test_dataset, 0)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    titles = ["Input Data", "Model Prediction", "Ground Truth"]

    # imshow only accepts 2-D (or RGB/RGBA) arrays, so show the central slice
    # along the first axis of each 3-D volume.
    for ax, volume, title in zip(axes[-3:], [data, pred, target], titles):
        ax.imshow(volume[volume.shape[0] // 2], cmap="viridis")
        ax.set_title(title)
        ax.axis("off")

    axes[0].plot(epochs, train_losses, label="Training Loss", color="tab:blue", marker="o")
    axes[0].plot(epochs, val_losses, label="Validation Loss", color="black", marker="o")
    axes[0].set_xlabel("Epochs")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Training Progress")
    axes[0].legend()
    axes[0].grid()

    plt.show()
    # Release the figure so memory does not grow across epochs.
    plt.close(fig)

    print(f"Epoch: {epoch}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

KeyboardInterrupt: 