In [64]:
# Third-party imports.
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# TensorFlow imports.
import tensorflow as tf
from tensorflow import keras
from keras.datasets import mnist
from keras.utils import to_categorical

# Loading Data

In [65]:
def parse_example(example_proto, height=256, width=256, channels=1):
    """Parse one serialized ``tf.train.Example`` into an ``(x, y)`` image pair.

    The record must contain two float features, ``X_data`` and ``Y_data``.
    Both are decoded as variable-length features and then reshaped to
    ``(height, width, channels)``.

    Args:
        example_proto: Scalar string tensor holding one serialized example.
        height: Target image height for the reshape (default 256).
        width: Target image width for the reshape (default 256).
        channels: Target channel count for the reshape (default 1).

    Returns:
        Tuple ``(x, y)`` of dense float32 tensors of shape
        ``(height, width, channels)``.

    Note:
        The shape is NOT inferred from the data. ``tf.reshape`` fails at
        runtime if a record's element count differs from
        ``height * width * channels``.
    """
    feature_description = {
        "X_data": tf.io.VarLenFeature(tf.float32),
        "Y_data": tf.io.VarLenFeature(tf.float32),
    }
    example = tf.io.parse_single_example(example_proto, feature_description)

    # VarLenFeature yields SparseTensors; densify them before reshaping.
    x = tf.sparse.to_dense(example["X_data"])
    y = tf.sparse.to_dense(example["Y_data"])

    # Input and target share the same fixed image shape.
    target_shape = (height, width, channels)
    x = tf.reshape(x, target_shape)
    y = tf.reshape(y, target_shape)

    return x, y

def load_tfrecord(filename, batch_size=None):
    """Build a ``tf.data`` pipeline of parsed ``(x, y)`` pairs from a TFRecord file.

    Args:
        filename: Path (or list of paths) to the TFRecord file(s).
        batch_size: If given, batch the parsed examples into batches of this
            size. Defaults to ``None``, which leaves the dataset unbatched.
            The rest of the notebook calls this function without a batch size
            and batches the result itself.

    Returns:
        A prefetching ``tf.data.Dataset`` of ``(x, y)`` tuples.
    """
    dataset = tf.data.TFRecordDataset(filename)
    # Parse several records concurrently instead of one at a time.
    dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    if batch_size is not None:
        dataset = dataset.batch(batch_size)
    # Overlap input preparation with model execution.
    return dataset.prefetch(tf.data.AUTOTUNE)

# One unbatched pipeline per split; each variable reads the file of the same split.
train_dataset = load_tfrecord("training_dataset.tfrecord")
val_dataset = load_tfrecord("validation_dataset.tfrecord")
test_dataset = load_tfrecord("testing_dataset.tfrecord")

# Model Architecture & Helper Functions

In [67]:
class Net(keras.Model):
    """Small two-level U-Net for single-channel image-to-image regression.

    Every convolution uses ``padding="same"``. Spatial size changes only
    through the 2x2 max-poolings and the stride-2 transposed convolutions:

    * Encoder: 2x Conv(32) -> MaxPool -> 2x Conv(64) -> MaxPool
    * Bottleneck: Conv(128)
    * Decoder: ConvT(64, stride 2) + skip -> 2x Conv(64)
               ConvT(32, stride 2) + skip -> 2x Conv(32)
    * Output: 1x1 Conv(1) with sigmoid, so predictions lie in (0, 1).

    Height and width must be divisible by 4, otherwise the upsampled tensors
    will not line up with the encoder features in the skip connections.
    """

    def __init__(self, input_shape=(None, 256, 256, 1), **kwargs):
        """Create all layers; weights are built lazily on the first call.

        Args:
            input_shape: Expected input shape including the batch dimension.
                Accepted for backward compatibility but not used. The previous
                ``keras.layers.Input(shape=input_shape)`` produced an unused
                5-D placeholder, because ``Input`` expects the shape WITHOUT
                the batch dimension.
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
        """
        super().__init__(**kwargs)

        # Encoder.
        self.conv1 = keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu")
        self.conv2 = keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu")
        self.pool1 = keras.layers.MaxPooling2D((2, 2))

        self.conv3 = keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu")
        self.conv4 = keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu")
        self.pool2 = keras.layers.MaxPooling2D((2, 2))

        # Bottleneck.
        self.conv5 = keras.layers.Conv2D(128, (3, 3), padding="same", activation="relu")

        # Decoder.
        self.upconv1 = keras.layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding="same")
        self.conv6 = keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu")
        self.conv7 = keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu")

        self.upconv2 = keras.layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding="same")
        self.conv8 = keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu")
        self.conv9 = keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu")

        # Skip-connection merges. Create them once here rather than
        # instantiating new layers on every forward pass.
        self.skip_concat1 = keras.layers.Concatenate()
        self.skip_concat2 = keras.layers.Concatenate()

        # Output layer: one sigmoid channel per pixel.
        self.output_layer = keras.layers.Conv2D(1, (1, 1), activation="sigmoid")

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Batch of images, channels-last (Keras default), presumably
                ``(batch, height, width, 1)`` float tensors. TODO confirm
                against the dataset.

        Returns:
            Tensor of shape ``(batch, height, width, 1)`` with values in (0, 1).
        """
        # Encoder level 1 (full resolution).
        enc1_a = self.conv1(inputs)
        enc1_b = self.conv2(enc1_a)
        down1 = self.pool1(enc1_b)

        # Encoder level 2 (1/2 resolution).
        enc2_a = self.conv3(down1)
        enc2_b = self.conv4(enc2_a)
        down2 = self.pool2(enc2_b)

        # Bottleneck (1/4 resolution).
        bottleneck = self.conv5(down2)

        # Decoder level 2: upsample to 1/2 resolution and fuse with the last
        # encoder-level-2 features (the tensor fed into pool2).
        up2 = self.upconv1(bottleneck)
        dec2 = self.skip_concat1([up2, enc2_b])
        dec2 = self.conv6(dec2)
        dec2 = self.conv7(dec2)

        # Decoder level 1: upsample to full resolution and fuse with the last
        # encoder-level-1 features (the tensor fed into pool1), mirroring level 2.
        up1 = self.upconv2(dec2)
        dec1 = self.skip_concat2([up1, enc1_b])
        dec1 = self.conv8(dec1)
        dec1 = self.conv9(dec1)

        return self.output_layer(dec1)
    
def compute_validation_loss(validation_data, batch_size, net=None):
    """Return the mean absolute error of a model over a validation dataset.

    Called at the end of every training epoch.

    Args:
        validation_data: Unbatched ``tf.data.Dataset`` of ``(x, y)`` pairs.
            It is batched here with ``batch_size``.
        batch_size: Number of examples per evaluation batch.
        net: Model to evaluate. Defaults to the notebook-level ``model``.

    Returns:
        Mean, over batches, of the per-batch MAE, as a Python float.

    Raises:
        ValueError: If ``validation_data`` yields no batches.
    """
    if net is None:
        net = model  # notebook-global model created in the "View Model" cell

    total_loss = 0.0
    num_batches = 0

    # ``Dataset.batch`` returns a NEW dataset, so iterate over its result.
    for X_batch, Y_batch in validation_data.batch(batch_size):
        # Inference-mode forward pass.
        Y_pred_val = net(X_batch, training=False)

        # Per-pixel MAE averaged over the whole batch.
        batch_val_loss = tf.reduce_mean(keras.losses.MAE(Y_batch, Y_pred_val))

        total_loss += float(batch_val_loss)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("validation_data is empty; cannot compute a validation loss.")

    # Average over the number of batches actually seen.
    return total_loss / num_batches

# View Model

In [68]:
# Create the model instance.
model = Net()

# Subclassed Keras models create their weights lazily. Push one dummy batch
# through the network to build it before printing the summary. Keras uses a
# channels-last layout by default: (batch, height, width, channels).
model(tf.random.normal((1, 256, 256, 1)))

# Print the model summary.
model.summary()

# Train Model

In [69]:
# Hyperparameters.
EPOCHS = 1
LEARNING_RATE = 0.001
BATCH_SIZE = 32
SHUFFLE_BUFFER = 10000

# Shuffle and batch into a NEW name. Rebinding ``train_dataset`` itself would
# batch an already-batched dataset whenever this cell is re-run.
train_batches = train_dataset.shuffle(buffer_size=SHUFFLE_BUFFER).batch(BATCH_SIZE)

optimiser = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# Running mean of the per-batch training loss within an epoch.
train_loss_metric = keras.metrics.Mean(name="train_loss")


@tf.function
def train_step(X_batch, Y_batch):
    """Run one optimisation step on a batch and return its mean absolute error."""
    with tf.GradientTape() as tape:
        Y_pred = model(X_batch, training=True)
        loss = tf.reduce_mean(keras.losses.MAE(Y_batch, Y_pred))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimiser.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


for epoch in range(1, EPOCHS + 1):
    train_loss_metric.reset_state()

    for X_batch, Y_batch in train_batches:
        train_loss_metric.update_state(train_step(X_batch, Y_batch))

    # Evaluate on the (unbatched) validation split at the end of every epoch.
    val_loss = compute_validation_loss(val_dataset, BATCH_SIZE)

    print(
        f"Epoch: {epoch}, "
        f"Training Loss: {float(train_loss_metric.result()):.6f}, "
        f"Validation Loss: {float(val_loss):.6f}"
    )

KeyboardInterrupt: 