# Forward-Forward Algorithm

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tensorflow.python.keras import layers, Model
from keras.datasets import cifar10
from keras.optimizers import Adam, SGD
from keras.utils import to_categorical
import numpy as np
from typing import Any

## Label Manipulator

In [None]:
class LabelManipulator:
    """
    A class to handle label manipulation tasks such as generating negative labels
    for contrastive divergence and overlaying labels onto input data.
    """

    def __init__(self, num_classes: int = 10, device: 'str' = 'cpu'):
        """
        Initialize the LabelManipulator class.

        Args:
            num_classes (int): The number of possible classes. Default value is 10.
            device (str): The device to perform computation. Default is cpu.
                NOTE(review): stored on the instance but not read by any method here.
        """
        self.num_classes = num_classes
        self.device = device

    def get_y_neg(self, y: tf.Tensor) -> tf.Tensor:
        """
        Generate negative labels for contrastive-divergence-style training.

        Every label is replaced by a label drawn uniformly at random from the
        other ``num_classes - 1`` classes, so a negative label never equals the
        ground-truth label.

        Args:
            y (tf.Tensor): Integer ground-truth labels of any shape, e.g.
                (batch,) or (batch, 1). NumPy arrays are accepted too.

        Returns:
            tf.Tensor: Negative labels with the same shape and dtype as ``y``.

        Raises:
            ValueError: If ``num_classes < 2`` (no incorrect label exists) or a
                label lies outside ``[0, num_classes)``.
        """
        y_tensor = tf.convert_to_tensor(y)
        labels = np.asarray(y_tensor).astype(np.int64)

        if self.num_classes < 2:
            raise ValueError("At least two classes are needed to draw a negative label.")
        if labels.size and (labels.min() < 0 or labels.max() >= self.num_classes):
            raise ValueError(f"Labels must lie in [0, {self.num_classes}).")

        # Adding a random offset in [1, num_classes - 1] modulo num_classes picks
        # each incorrect class with equal probability and never the true class.
        # This is the same distribution as choosing from the remaining classes,
        # but vectorized instead of rebuilding the tensor once per sample.
        offsets = np.random.randint(1, self.num_classes, size=labels.shape)
        negatives = (labels + offsets) % self.num_classes
        return tf.convert_to_tensor(negatives.astype(y_tensor.dtype.as_numpy_dtype))

    def overlay_y_on_x(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
        """
        Overlay label information onto a batch of images.

        For every sample, the first ``num_classes`` pixels of row 0 are set to
        0 in all channels, except the pixel at column ``label``, which is set
        to 1 in all channels. All other pixels are copied unchanged.

        Args:
            x (tf.Tensor): Channels-last input of shape
                (batch_size, height, width, channels); ``width`` must be at
                least ``num_classes``.
            y (tf.Tensor): Integer labels, either one per sample (e.g. shape
                (batch,) or (batch, 1)) or a single label that is applied to
                every sample in the batch.

        Returns:
            tf.Tensor: A new tensor with the same shape and dtype as ``x``
                carrying the label overlay; ``x`` itself is not modified.

        Raises:
            ValueError: If ``x`` is not 4-D, is narrower than ``num_classes``,
                the number of labels matches neither 1 nor the batch size, or
                a label lies outside ``[0, num_classes)``.
        """
        x_tensor = tf.convert_to_tensor(x)
        if len(x_tensor.shape) != 4:
            raise ValueError("Input tensor must have 4-dimensions.")
        batch_size, height, width, _ = x_tensor.shape
        if width < self.num_classes:
            raise ValueError(f"Image width {width} is smaller than num_classes={self.num_classes}.")

        labels = np.asarray(y).reshape(-1).astype(np.int64)
        if labels.size == 1:
            # A single label is broadcast so every sample gets the overlay.
            labels = np.full(batch_size, labels[0], dtype=np.int64)
        elif labels.size != batch_size:
            raise ValueError(f"Expected 1 or {batch_size} labels, got {labels.size}.")
        if labels.size and (labels.min() < 0 or labels.max() >= self.num_classes):
            raise ValueError(f"Labels must lie in [0, {self.num_classes}).")

        # Mask (1, H, W, 1) selecting the label strip: row 0, columns [0, num_classes).
        in_strip = np.zeros((1, height, width, 1), dtype=bool)
        in_strip[:, 0, :self.num_classes, :] = True

        # Per-sample one-hot row (B, 1, W, 1): 1 at the label column, 0 elsewhere.
        marks = tf.one_hot(labels, depth=width, dtype=x_tensor.dtype)
        marks = marks[:, tf.newaxis, :, tf.newaxis]

        # Broadcast over height/channels: strip pixels take the one-hot value,
        # everything else keeps the original pixel.
        return tf.where(in_strip, marks, x_tensor)

# Custom Fully Connected Layer

In [None]:
from sys import argv


class FullyConnectedLayer(tf.keras.layers.Layer):
    """
    A fully connected (dense) layer with L2 input normalization, ReLU activation
    and a layer-local Forward-Forward training loop.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, final_layer: bool = True,
                 learning_rate: float = None, fcl_threshold: float = None, num_epochs: int = None,
                 log_interval: int = None):
        """
        Initialize the fully connected layer.

        Args:
            in_features (int): The number of input features. Informational only:
                the Dense sublayer infers its input size on first call.
            out_features (int): The number of output features.
            bias (bool): Whether to include a bias term. Default is True.
            final_layer (bool): Whether this is the final layer of the network.
                Default is True. NOTE(review): stored but not read in this class.
            learning_rate (float): Adam learning rate. Defaults to ``args.lr``.
            fcl_threshold (float): Goodness threshold used in the training loss.
                Defaults to ``args.fcl_threshold``.
            num_epochs (int): Number of epochs for custom training.
                Defaults to ``args.epochs``.
            log_interval (int): Log the loss every ``log_interval`` epochs
                (0 disables logging). Defaults to ``args.log_interval``.
        """
        super(FullyConnectedLayer, self).__init__()
        # Use tf.keras sublayers and optimizer, matching the conv Layer class.
        # A private/legacy `tensorflow.python.keras` layer nested inside a
        # tf.keras layer may not have its weights tracked by `trainable_variables`.
        self.linear = tf.keras.layers.Dense(units=out_features, use_bias=bias)
        self.relu = tf.keras.layers.ReLU()
        # Unset hyperparameters fall back to the notebook-level `args` config,
        # which is where this layer previously read them unconditionally.
        self.opt = tf.keras.optimizers.Adam(learning_rate=args.lr if learning_rate is None else learning_rate)
        self.fcl_threshold = args.fcl_threshold if fcl_threshold is None else fcl_threshold
        self.num_epochs = args.epochs if num_epochs is None else num_epochs
        self.log_interval = args.log_interval if log_interval is None else log_interval
        self.final_layer = final_layer

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Keras entry point so that `layer(x)` applies this layer's transformation."""
        return self.forward(x)

    def forward(self, x: tf.Tensor) -> tf.Tensor:
        """
        Forward pass: L2-normalize the input features, apply the dense layer, then ReLU.

        Args:
            x (tf.Tensor): Input tensor whose last axis is the feature axis,
                e.g. (batch, features).

        Returns:
            tf.Tensor: Output tensor after L2 normalization, linear transformation
                and ReLU activation.
        """
        # Normalizing removes the input's length (the previous layer's goodness),
        # so this layer must rely on the direction of the activity vector.
        x_direction = tf.linalg.l2_normalize(x, axis=-1)
        return self.relu(self.linear(x_direction))

    @staticmethod
    def _goodness(h: tf.Tensor) -> tf.Tensor:
        """Return the mean squared activation per sample (reduced over every non-batch axis)."""
        return tf.reduce_mean(tf.square(h), axis=tf.range(1, tf.rank(h)))

    def custom_train(self, x_pos: tf.Tensor, x_neg: tf.Tensor):
        """
        Train this layer alone using positive and negative samples.

        The loss pushes per-sample goodness above ``fcl_threshold`` for positive
        data and below it for negative data: mean(softplus(threshold - g_pos),
        softplus(g_neg - threshold)).

        Args:
            x_pos (tf.Tensor): Positive input data.
            x_neg (tf.Tensor): Negative input data.

        Returns:
            Tuple[tf.Tensor, tf.Tensor]: This layer's outputs for the positive and
                negative data after training, to feed the next layer.
        """
        for i in range(self.num_epochs):
            print(f"Models custom fully-connected layer training data on epoch {i}")
            with tf.GradientTape() as tape:
                # Per-sample goodness, shape (batch,), so the two can be concatenated.
                g_pos = self._goodness(self.forward(x_pos))
                g_neg = self._goodness(self.forward(x_neg))

                # softplus(z) == log(1 + exp(z)) without overflow for large z.
                margins = tf.concat([self.fcl_threshold - g_pos, g_neg - self.fcl_threshold], axis=0)
                loss = tf.reduce_mean(tf.math.softplus(margins))

            gradients = tape.gradient(loss, self.trainable_variables)
            self.opt.apply_gradients(zip(gradients, self.trainable_variables))

            if self.log_interval and i % self.log_interval == 0:
                print(f"Epoch {i}: Loss = {loss.numpy()}")

        # Outputs after training become the next layer's inputs.
        return self.forward(x_pos), self.forward(x_neg)


# Custom Convolution Layer

In [None]:
class Layer(tf.keras.layers.Layer):
    """
    A convolutional layer with L2 normalization, ReLU activation and a layer-local
    Forward-Forward training loop.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
                 padding: str = 'valid', bias: bool = True, final_layer: bool = False,
                 learning_rate: float = 0.001, conv_threshold: float = 1.0, num_epochs: int = 100,
                 log_interval: int = 10):
        """
        Initialize the Layer class.

        Args:
            in_channels (int): Number of input channels. Informational only: Conv2D
                infers its input channels on first call.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel. Default is 3.
            stride (int): Stride for the convolution. Default is 1.
            padding (str): Padding method ('valid' or 'same'). Default is 'valid'.
            bias (bool): Whether to include a bias term in the convolution. Default is True.
            final_layer (bool): Whether this is the final layer of the network. Default is False.
            learning_rate (float): Learning rate for the optimizer. Default is 0.001.
            conv_threshold (float): Goodness threshold used in the training loss. Default is 1.0.
            num_epochs (int): Number of epochs for custom training. Default is 100.
            log_interval (int): Interval (in epochs) for recording and plotting
                progress; 0 disables plotting. Default is 10.
        """
        super(Layer, self).__init__()
        self.conv = tf.keras.layers.Conv2D(filters=out_channels, kernel_size=kernel_size, strides=stride,
                                           padding=padding, use_bias=bias)
        self.relu = tf.keras.layers.ReLU()
        self.opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.conv_threshold = conv_threshold
        self.num_epochs = num_epochs
        self.log_interval = log_interval
        self.final_layer = final_layer

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Forward pass: L2-normalize across channels, convolve, then apply ReLU.

        Args:
            x (tf.Tensor): Channels-last input of shape (batch, height, width, channels).

        Returns:
            tf.Tensor: Output tensor after the forward pass.
        """
        # Conv2D here is channels-last, so the feature (channel) axis is -1;
        # axis=1 would normalize along image height instead.
        x_direction = tf.nn.l2_normalize(x, axis=-1)
        return self.relu(self.conv(x_direction))

    @staticmethod
    def _plot_progress(loss_values, g_pos_values, g_neg_values):
        """Redraw the loss / goodness curves, replacing the previous cell output."""
        fig = plt.figure(figsize=(12, 8))
        panels = (
            (loss_values, 'blue', "Loss during training"),
            (g_pos_values, 'green', "g_pos during training"),
            (g_neg_values, 'red', "g_neg during training"),
        )
        for row, (values, colour, title) in enumerate(panels, start=1):
            plt.subplot(3, 1, row)
            plt.plot(values, color=colour)
            plt.title(title)
        plt.tight_layout()
        clear_output(wait=True)  # Clear the output to update the plot
        plt.show()
        plt.close(fig)  # Don't accumulate open figures across redraws

    def custom_train(self, x_pos: tf.Tensor, x_neg: tf.Tensor):
        """
        Train this layer alone using positive and negative samples.

        The loss pushes per-sample goodness (mean squared activation) above
        ``conv_threshold`` for positive data and below it for negative data.

        Args:
            x_pos (tf.Tensor): Positive input data.
            x_neg (tf.Tensor): Negative input data.

        Returns:
            Tuple[tf.Tensor, tf.Tensor]: This layer's outputs for the positive and
                negative data after training, to feed the next layer.
        """
        loss_values = []
        g_pos_values = []
        g_neg_values = []

        for i in range(self.num_epochs):
            print(f"Model custom layer learning data on epoch {i}")
            with tf.GradientTape() as tape:
                # Per-sample goodness, shape (batch,).
                g_pos = tf.reduce_mean(tf.square(self.call(x_pos)), axis=[1, 2, 3])
                g_neg = tf.reduce_mean(tf.square(self.call(x_neg)), axis=[1, 2, 3])

                # softplus(z) == log(1 + exp(z)) without overflow for large z.
                margins = tf.concat([self.conv_threshold - g_pos, g_neg - self.conv_threshold], axis=0)
                loss = tf.reduce_mean(tf.math.softplus(margins))

            gradients = tape.gradient(loss, self.trainable_variables)
            self.opt.apply_gradients(zip(gradients, self.trainable_variables))

            if self.log_interval and i % self.log_interval == 0:
                loss_values.append(loss.numpy())
                g_pos_values.append(tf.reduce_mean(g_pos).numpy())
                g_neg_values.append(tf.reduce_mean(g_neg).numpy())
                self._plot_progress(loss_values, g_pos_values, g_neg_values)

            print(f'Loss at step {i}: {loss.numpy()}')

        # Outputs after training become the next layer's inputs.
        return self.call(x_pos), self.call(x_neg)

## Custom Network Class

In [None]:
class Net(tf.keras.Model):
    """
    A VGG-style stack of Forward-Forward conv layers followed by fully connected
    layers, trained greedily layer by layer via `custom_train`.
    Inherits from `tf.keras.Model`.
    """

    def __init__(self, num_classes: int = 10):
        """
        Initialize the Net model with convolutional layers, fully connected layers,
        and ReLU activations.

        Args:
            num_classes (int): Number of output classes. Default is 10.
        """
        super(Net, self).__init__()
        self.num_classes = num_classes

        # Used by `predict` to stamp each candidate label onto the input.
        self.label_manipulator = LabelManipulator(num_classes=num_classes)

        self.layers_ = [
            # Block 1
            Layer(3, 64, kernel_size=3, padding='same'),  # Conv 1
            tf.keras.layers.ReLU(),
            Layer(64, 64, kernel_size=3, padding='same'),  # Conv 2
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),  # MaxPooling

            # Block 2
            Layer(64, 128, kernel_size=3, padding='same'),  # Conv 3
            tf.keras.layers.ReLU(),
            Layer(128, 128, kernel_size=3, padding='same'),  # Conv 4
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),  # MaxPooling

            # Block 3
            Layer(128, 256, kernel_size=3, padding='same'),  # Conv 5
            tf.keras.layers.ReLU(),
            Layer(256, 256, kernel_size=3, padding='same'),  # Conv 6
            tf.keras.layers.ReLU(),
            Layer(256, 256, kernel_size=3, padding='same'),  # Conv 7
            tf.keras.layers.ReLU(),
            Layer(256, 256, kernel_size=3, padding='same'),  # Conv 8
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),  # MaxPooling

            # Block 4
            Layer(256, 512, kernel_size=3, padding='same'),  # Conv 9
            tf.keras.layers.ReLU(),
            Layer(512, 512, kernel_size=3, padding='same'),  # Conv 10
            tf.keras.layers.ReLU(),
            Layer(512, 512, kernel_size=3, padding='same'),  # Conv 11
            tf.keras.layers.ReLU(),
            Layer(512, 512, kernel_size=3, padding='same'),  # Conv 12
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),  # MaxPooling

            # Block 5
            Layer(512, 512, kernel_size=3, padding='same'),  # Conv 13
            tf.keras.layers.ReLU(),
            Layer(512, 512, kernel_size=3, padding='same'),  # Conv 14
            tf.keras.layers.ReLU(),
            Layer(512, 512, kernel_size=3, padding='same'),  # Conv 15
            tf.keras.layers.ReLU(),
            Layer(512, 512, kernel_size=3, padding='same'),  # Conv 16
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),  # MaxPooling

            # Fully Connected Layers
            tf.keras.layers.Flatten(),  # Flatten the input for Dense layers
            # A 32x32 CIFAR-10 input is halved by five 2x2 max-pools to 1x1x512,
            # so Flatten yields 512 features (Dense infers this on first call).
            FullyConnectedLayer(512, 4096),  # Dense 1 (4096 units)
            tf.keras.layers.ReLU(),
            FullyConnectedLayer(4096, 4096),  # Dense 2 (4096 units)
            tf.keras.layers.ReLU(),
            FullyConnectedLayer(4096, 10, final_layer=True),  # Output layer (10 classes for CIFAR-10)
        ]

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Forward pass through every layer in order.

        Args:
            x (tf.Tensor): Input tensor.

        Returns:
            tf.Tensor: Output tensor after passing through all layers.
        """
        for layer in self.layers_:
            x = layer(x)
        return x

    def predict(self, x: tf.Tensor) -> tf.Tensor:
        """
        Predict a label per sample by trying every label and picking the best goodness.

        For each candidate label, the label is overlaid on every sample. The
        per-sample goodness (mean squared activation) of every layer's output
        is summed over layers.

        Args:
            x (tf.Tensor): Batch of channels-last input images.

        Returns:
            tf.Tensor: int64 tensor of shape (batch,) holding, for each sample, the
                label with the highest total goodness.
        """
        x = tf.convert_to_tensor(x)
        batch_size = tf.shape(x)[0]
        goodness_per_label = []
        for label in range(self.num_classes):
            # One label per sample, so every sample (not just the first) is overlaid.
            h = self.label_manipulator.overlay_y_on_x(x, tf.fill([batch_size], label))
            per_layer = []
            for layer in self.layers_:
                h = layer(h)
                # Reduce every axis except the batch axis -> shape (batch,).
                per_layer.append(tf.reduce_mean(tf.square(h), axis=tf.range(1, tf.rank(h))))
            goodness_per_label.append(tf.add_n(per_layer))

        # (batch, num_classes) -> best label per sample.
        return tf.argmax(tf.stack(goodness_per_label, axis=1), axis=1)

    def custom_train(self, x_pos: tf.Tensor, x_neg: tf.Tensor):
        """
        Greedily train the network one layer at a time.

        Forward-Forward layers (`Layer`, `FullyConnectedLayer`) are trained on the
        current positive/negative activations and return their outputs. Every
        other layer (ReLU, pooling, flatten) just passes both streams through.

        Args:
            x_pos (tf.Tensor): Positive input data.
            x_neg (tf.Tensor): Negative input data.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers_):
            print(f"Layer {i} trainable variables: {layer.trainable_variables}")
            print(f"Training layer: {i}")
            if isinstance(layer, (Layer, FullyConnectedLayer)):
                h_pos, h_neg = layer.custom_train(h_pos, h_neg)
            else:
                h_pos = layer(h_pos)
                h_neg = layer(h_neg)

# Network Hyperparameter Setup

In [None]:
# class Args:
#     """
#     A class to store hyperparameters and configuration settings.
#     """
#     train_size = 1000  # Default training size
#     test_size = 100  # Default testing size
#     epochs = 1000  # Number of epochs for training
#     lr = 0.05  # Learning rate
#     no_cuda = False  # Whether to disable CUDA (GPU acceleration)
#     no_mps = False  # Whether to disable MPS (Apple's GPU acceleration)
#     save_model = False  # Whether to save the trained model
#     fcl_threshold = 1  # Threshold for FullyConnectedLayer
#     conv_threshold = 0.02  # Threshold for convolutional layers
#     seed = 1234  # Random seed for reproducibility
#     log_interval = 10  # How often to log training information

class Args:
    """
    A class to store hyperparameters and configuration settings.
    """
    train_size = 50000  # Full CIFAR-10 training set size
    test_size = 10000  # Full CIFAR-10 test set size
    epochs = 200  # Number of epochs for training
    lr = 0.001  # Learning rate for stable training (adjust if needed)
    no_cuda = False  # Whether to disable CUDA (GPU acceleration)
    no_mps = False  # Whether to disable MPS (Apple's GPU acceleration)
    save_model = True  # Whether to save the trained model
    fcl_threshold = 1  # Threshold for FullyConnectedLayer
    # NOTE(review): Net builds its conv `Layer`s without passing this value, so
    # they use Layer's own default threshold instead.
    conv_threshold = 0.02  # Threshold for convolutional layers
    seed = 1234  # Random seed for reproducibility
    log_interval = 10  # Logging interval for the per-layer training loops


# Create an instance of the Args class
args = Args()

# Seed both RNGs used by the notebook. TensorFlow drives dataset shuffling and
# weight initialization; NumPy drives negative-label sampling in
# LabelManipulator.get_y_neg.
tf.random.set_seed(args.seed)
np.random.seed(args.seed)

# Determine whether to use GPU or CPU (booleans, not device lists).
use_cuda = not args.no_cuda and bool(tf.config.list_physical_devices('GPU'))
# NOTE(review): TensorFlow may never report an 'MPS' device type (Apple GPUs
# usually appear as 'GPU' via tensorflow-metal) — confirm this can be true.
use_mps = not args.no_mps and bool(tf.config.list_physical_devices('MPS'))

if use_cuda:
    device = "GPU"
elif use_mps:
    device = "MPS"
else:
    device = "CPU"

print(f"Using device: {device}")

# Set batch size configurations
train_kwargs = {"batch_size": args.train_size}
test_kwargs = {"batch_size": args.test_size}

# If GPU (or MPS) is available, configure additional options for performance optimization
if use_cuda or use_mps:
    additional_kwargs = {"shuffle": True}
    train_kwargs.update(additional_kwargs)
    test_kwargs.update(additional_kwargs)

# Example use of the configurations
print(f"Training configuration: {train_kwargs}")
print(f"Testing configuration: {test_kwargs}")

# Data Loading

In [None]:
def normalize_images(image, mean, std):
    """Scale pixels to [0, 1] as float32, then standardize each channel with mean/std."""
    scaled = tf.cast(image, tf.float32) / 255.0
    return (scaled - mean) / std


# Fetch the CIFAR-10 train/test splits (images are channels-last).
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Per-channel (R, G, B) statistics used for standardization.
mean = tf.constant([0.4914, 0.4822, 0.4465])
std = tf.constant([0.2023, 0.1994, 0.2010])

train_images, test_images = (normalize_images(split, mean, std) for split in (train_images, test_images))

# Labels are kept as integer class ids; enable these lines for one-hot labels instead:
# train_labels = to_categorical(train_labels, num_classes=10)
# test_labels = to_categorical(test_labels, num_classes=10)

# Build the input pipelines: shuffle (training only), batch, then prefetch.
# Batch size 64 is used instead of args.train_size / args.test_size.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(buffer_size=50000)
    .batch(64)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(64)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)


# Initialize and Train the Network

In [None]:
# Instantiate the model and move it to the correct device (if GPU or MPS is available)
net = Net()
device = "GPU" if tf.config.list_physical_devices('GPU') else "CPU"
print(f"Model will run on: {device}")

In [None]:
# Load a batch of data (using TensorFlow dataset)
# Grab one shuffled training batch: x is (batch, H, W, C) images, y the matching labels.
x, y = next(iter(train_dataset.take(1)))

In [None]:
# Overlay the labels on the input data
# Positive data carries the true label; negative data carries a wrong label.
label_manipulator = LabelManipulator(num_classes=10)
y_neg = label_manipulator.get_y_neg(y)
x_pos, x_neg = (label_manipulator.overlay_y_on_x(x, labels) for labels in (y, y_neg))

# Report the tensor shapes of the sampled batch.
for line in (f"x shape: {x.shape}", f"Batch size: {x.shape[0]}", f"y shape: {y.shape}"):
    print(line)

In [None]:
# Visualize samples
def _to_displayable(image):
    """Min-max rescale one image to [0, 1] so imshow does not clip standardized pixels."""
    image = np.asarray(image, dtype=np.float32)
    low, high = image.min(), image.max()
    return (image - low) / (high - low) if high > low else np.zeros_like(image)


def _label_at(labels, i):
    """Return sample i's label as a Python int (labels may be shaped (batch,) or (batch, 1))."""
    return int(np.asarray(labels[i]).reshape(-1)[0])


# Show up to 5 samples: original, positive (true label) and negative (wrong label).
num_rows = min(5, x.shape[0])
fig, axs = plt.subplots(num_rows, 3, figsize=(10, 2 * num_rows), squeeze=False)

# Define a dictionary to map class indices to class names
class_dict = {i: 'class_' + str(i) for i in range(10)}

columns = (('Original: ', x, y), ('Positive: ', x_pos, y), ('Negative: ', x_neg, y_neg))
for i in range(num_rows):
    for col, (prefix, images, labels) in enumerate(columns):
        # Images are already channels-last (H, W, C), so no transpose is needed.
        img = images[i].numpy()
        axs[i, col].imshow(_to_displayable(img), interpolation='nearest')
        axs[i, col].set_title(prefix + class_dict[_label_at(labels, i)] + '\n Shape: ' + str(img.shape))

# Remove axis ticks
for ax in axs.flat:
    ax.axis('off')

# Adjust layout for better visualization
plt.tight_layout()
plt.show()

# Train the network

In [None]:
# Greedy layer-by-layer Forward-Forward training on the single positive/negative batch built above.
net.custom_train(x_pos, x_neg)