In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from tqdm import tqdm
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda, RandomHorizontalFlip, RandomCrop, ColorJitter
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10, FashionMNIST
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
# Data Loader
class DataLoaderManager:
    """
    Builds training and test ``DataLoader`` objects for MNIST, FashionMNIST or CIFAR10.

    Attributes:
        dataset_name (str): Name of the dataset to load.
        train_batch_size (int): Batch size of the training loader.
        test_batch_size (int): Batch size of the test loader.
        transform (Compose): Pipeline applied to the training split.
        eval_transform (Compose): Pipeline applied to the test split (no random augmentation).
    """

    # Dataset names that ``load_dataset`` knows how to construct.
    SUPPORTED_DATASETS = ('MNIST', 'CIFAR10', 'FashionMNIST')

    def __init__(self, dataset_name, train_batch_size=10000, test_batch_size=1000):
        """
        Initializes the DataLoaderManager with specified dataset and batch sizes.

        Args:
            dataset_name (str): Name of the dataset ('MNIST', 'CIFAR10', 'FashionMNIST').
            train_batch_size (int): Batch size for the training dataset.
            test_batch_size (int): Batch size for the test dataset.
        """
        self.dataset_name = dataset_name
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.transform = self.get_transforms(train=True)
        self.eval_transform = self.get_transforms(train=False)

    def get_transforms(self, train=True):
        """
        Returns the transformation pipeline for the configured dataset.

        Args:
            train (bool): If True (default), include the random CIFAR10 augmentations.
                If False, return the deterministic pipeline used for evaluation.
                Non-CIFAR10 pipelines have no random steps, so the flag does not
                change them.

        Returns:
            torchvision.transforms.Compose: Transformation pipeline.
        """
        if self.dataset_name == 'CIFAR10':
            steps = []
            if train:
                # Random augmentation is only meaningful for the training split;
                # applying it at test time would make evaluation noisy.
                steps.extend([
                    RandomHorizontalFlip(),
                    RandomCrop(32, padding=4),
                    ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                ])
            steps.extend([
                ToTensor(),
                Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])
            return Compose(steps)

        # Every other dataset name (MNIST, FashionMNIST, ...) gets the same
        # single-channel normalization followed by flattening to a 1-D vector.
        # NOTE(review): FashionMNIST reuses the same mean/std constants as MNIST
        # here - confirm that is intended.
        return Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,)),
            # Passing the function itself (instead of a lambda) keeps the
            # pipeline picklable for multi-process data loading.
            Lambda(torch.flatten),
        ])

    def load_dataset(self, train=True):
        """
        Loads the specified dataset, downloading it to ``./data/`` if needed.

        Args:
            train (bool): If True, load training dataset; otherwise, load test dataset.

        Returns:
            Dataset: The requested dataset.

        Raises:
            ValueError: If ``dataset_name`` is not one of ``SUPPORTED_DATASETS``.
        """
        # Fail loudly instead of returning None, which would only surface later
        # as an obscure error inside DataLoader.
        if self.dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(
                f"Unsupported dataset '{self.dataset_name}'. "
                f"Choose from {', '.join(self.SUPPORTED_DATASETS)}."
            )

        dataset_classes = {'MNIST': MNIST, 'CIFAR10': CIFAR10, 'FashionMNIST': FashionMNIST}
        dataset_cls = dataset_classes[self.dataset_name]
        transform = self.transform if train else self.eval_transform
        return dataset_cls('./data/', train=train, download=True, transform=transform)

    def get_data_loaders(self):
        """
        Creates and returns data loaders for training and testing datasets.

        The training loader is shuffled; the test loader keeps dataset order.

        Returns:
            tuple: Tuple containing the training and testing data loaders.
        """
        train_loader = DataLoader(self.load_dataset(train=True),
                                  batch_size=self.train_batch_size, shuffle=True)
        test_loader = DataLoader(self.load_dataset(train=False),
                                 batch_size=self.test_batch_size, shuffle=False)
        return train_loader, test_loader

In [None]:
class NegativeDataGenerator:
    """
    Generates negative samples from a batch of images.

    Three strategies are available through ``generate(method=...)``:
    - "long_range": hybrid images that mix each image with another one from
      the batch through a blob-shaped binary mask.
    - "label_swap": images with a wrong label written into the top row.
    - "random_noise": images with a random fraction of values replaced by noise.

    All methods expect image batches of shape (batch_size, channels, height, width).
    """

    def __init__(self, num_classes=10, noise_level=0.2):
        """
        Initialize the NegativeDataGenerator class with default parameters.

        Args:
            num_classes (int): The number of classes in the dataset.
            noise_level (float): Probability that a given value is replaced by
                noise in the 'random_noise' method.
        """
        self.num_classes = num_classes
        self.noise_level = noise_level

    def _create_hybrid_images(self, x):
        """
        Create negative examples by mixing every image with a different image
        from the same batch. One smoothed random binary mask, shared by the
        whole batch, decides which image each pixel is taken from.

        Args:
            x (torch.Tensor): Batch of images (batch_size, channels, height, width).

        Returns:
            torch.Tensor: A batch of hybrid negative images (same shape as ``x``).

        Raises:
            ValueError: If the batch holds fewer than two images, since there
                is then no other image to mix with.
        """
        batch_size, _, height, width = x.shape
        if batch_size < 2:
            raise ValueError("The 'long_range' method needs at least two images in the batch.")

        # Uniform noise, blurred with a 3x3 average and thresholded, gives a
        # mask made of small contiguous blobs. Shape (1, 1, H, W) broadcasts
        # over batch and channels.
        mask = torch.rand((1, 1, height, width), device=x.device)
        mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1)
        mask = (mask > 0.5).to(x.dtype)

        # A non-zero offset modulo batch_size picks, uniformly, a partner
        # index that is guaranteed to differ from the image's own index.
        own_index = torch.arange(batch_size, device=x.device)
        offsets = torch.randint(1, batch_size, (batch_size,), device=x.device)
        partner_index = (own_index + offsets) % batch_size

        return x * mask + x[partner_index] * (1 - mask)

    def _label_swapped_images(self, x, y):
        """
        Generate negative examples by overlaying incorrect labels on the images.

        Args:
            x (torch.Tensor): Batch of images (batch_size, channels, height, width).
            y (torch.Tensor): True labels corresponding to the images.

        Returns:
            torch.Tensor: A batch of images with incorrect label overlay.
        """
        return self._overlay_y_on_x(x, self._get_y_neg(y))

    def _random_noise_corruption(self, x):
        """
        Generate negative examples by replacing a random subset of values with
        uniform noise drawn from [0, 1).

        Args:
            x (torch.Tensor): Batch of images (batch_size, channels, height, width).

        Returns:
            torch.Tensor: A batch of corrupted images.
        """
        noise = torch.rand_like(x)
        # Each value is replaced independently with probability noise_level.
        replace = torch.rand_like(x) < self.noise_level
        return torch.where(replace, noise, x)

    def _overlay_y_on_x(self, x, y):
        """
        Write a one-hot style label marker into the top row of every image.

        The first ``num_classes`` pixels of row 0 are cleared in every channel,
        then the pixel at column ``y[i]`` of image ``i`` is set to the maximum
        value found anywhere in ``x``.

        Args:
            x (torch.Tensor): Batch of images (batch_size, channels, height, width).
                The width must be at least ``num_classes``.
            y (torch.Tensor): Integer labels to overlay, one per image.

        Returns:
            torch.Tensor: A new tensor with the overlay applied; ``x`` is not modified.
        """
        x_ = x.clone()
        peak = x.max()
        sample_index = torch.arange(x.shape[0], device=x.device)
        x_[:, :, 0, :self.num_classes] = 0.0
        x_[sample_index, :, 0, y] = peak
        return x_

    def _get_y_neg(self, y):
        """
        Generates negative labels by ensuring each label is different from the original.

        Every wrong class is equally likely. Labels are assumed to lie in
        ``[0, num_classes)``.

        Args:
            y (torch.Tensor): A tensor containing the original labels.

        Returns:
            torch.Tensor: A tensor containing negative labels (same shape and dtype as ``y``).

        Raises:
            ValueError: If ``num_classes`` is smaller than 2, in which case no
                wrong label exists.
        """
        if self.num_classes < 2:
            raise ValueError("At least two classes are required to draw a wrong label.")

        # Adding an offset in [1, num_classes - 1] modulo num_classes can never
        # land on the original label, and runs as a single tensor operation on
        # whatever device holds y.
        offsets = torch.randint(1, self.num_classes, y.shape, device=y.device)
        return ((y + offsets) % self.num_classes).to(y.dtype)

    def generate(self, x, y=None, method="long_range"):
        """
        Generate negative data based on the specified method.

        Args:
            x (torch.Tensor): Batch of images (batch_size, channels, height, width).
            y (torch.Tensor, optional): Labels corresponding to the images (required for 'label_swap' method).
            method (str): The method for generating negative data ('long_range', 'label_swap', 'random_noise').

        Returns:
            torch.Tensor: A batch of negative images.

        Raises:
            ValueError: If ``method`` is unknown, if 'label_swap' is requested
                without labels, or if 'long_range' receives fewer than two images.
        """
        if method == "long_range":
            return self._create_hybrid_images(x)
        elif method == "label_swap":
            if y is None:
                raise ValueError("Labels (y) are required for the 'label_swap' method.")
            return self._label_swapped_images(x, y)
        elif method == "random_noise":
            return self._random_noise_corruption(x)
        else:
            raise ValueError(f"Invalid method '{method}'. Choose from 'long_range', 'label_swap', or 'random_noise'.")

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# Build the CIFAR10 train/test loaders with the manager's default batch sizes
data_manager = DataLoaderManager(dataset_name="CIFAR10")
(train_loader, test_loader) = data_manager.get_data_loaders()

In [None]:
# Pull the first training batch and place images and labels on the chosen device
x, y = (tensor.to(device) for tensor in next(iter(train_loader)))

In [None]:
# Shape of the image batch fetched above
x.shape

In [None]:
# Shape of the label batch fetched above
y.shape

In [None]:
def display_samples(x, x_neg, num_samples=5):
    """
    Display a few samples of original and generated negative data side by side.

    The top row shows originals, the bottom row the matching negatives. Each
    image is rescaled to [0, 1] independently for display only, so normalized
    inputs are not clipped by ``imshow``.

    Args:
        x (torch.Tensor): Batch of original images (batch, channels, height, width).
        x_neg (torch.Tensor): Batch of generated negative images, same layout.
        num_samples (int): Number of samples to display. Capped at the number
            of images available in both batches.

    Raises:
        ValueError: If there is no sample to display.
    """
    def _to_image(img):
        # (C, H, W) -> (H, W, C); squeeze drops a singleton channel so that
        # grayscale images become 2-D and pick up the 'gray' colormap.
        img = img.permute(1, 2, 0).squeeze().float()
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            # Constant image: avoid dividing by zero.
            img = torch.zeros_like(img)
        return img.numpy()

    count = min(num_samples, x.shape[0], x_neg.shape[0])
    if count < 1:
        raise ValueError("display_samples needs at least one sample to show.")

    # Move data to CPU and detach from computation graph if necessary
    originals = x[:count].cpu().detach()
    negatives = x_neg[:count].cpu().detach()

    # squeeze=False keeps `axes` two-dimensional even when count == 1.
    fig, axes = plt.subplots(2, count, figsize=(15, 4), squeeze=False)
    for row, (title, images) in enumerate((("Original", originals), ("Negative", negatives))):
        for col in range(count):
            ax = axes[row, col]
            ax.imshow(_to_image(images[col]), cmap='gray')
            ax.axis('off')
            ax.set_title(title)

    plt.show()

In [None]:
# Initialize the NegativeDataGenerator
negative_data_generator = NegativeDataGenerator(num_classes=10, noise_level=0.2)

# Generate negative data using the "label_swap" method: every image gets a
# randomly chosen wrong label written into the first pixels of its top row
# NOTE(review): the positive batch x used for training carries no label overlay,
# while predict() overlays candidate labels on its input - confirm this is intended.
x_neg = negative_data_generator.generate(x, y, method="label_swap")

# Display original and negative samples
display_samples(x, x_neg, num_samples=5)

In [None]:
# # Generate negative data using the "random_noise" method as an example
# x_neg_1 = negative_data_generator.generate(x, y, method="random_noise")

# # Display original and negative samples
# display_samples(x, x_neg_1, num_samples=5)

In [None]:
class PeerNormalization(nn.Module):
    """
    Peer normalization: re-centres every unit's activity on the layer-wide
    average activity.

    For each unit (a channel of a 4D input, a feature of a 2D input) the mean
    activity over the batch is removed and replaced by the mean of those
    per-unit means, so all units end up with the same average activity.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Apply peer normalization.

        Args:
            x (torch.Tensor): Either a 4D tensor (batch, channels, height, width)
                or a 2D tensor (batch, features).

        Returns:
            torch.Tensor: Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is neither 2D nor 4D.
        """
        rank = x.dim()
        if rank not in (2, 4):
            raise ValueError(f"Unsupported tensor dimension: {x.dim()}")

        # 4D: one mean per channel, taken over batch, height and width.
        # 2D: one mean per feature, taken over the batch.
        reduce_over = (0, 2, 3) if rank == 4 else 0
        unit_mean = x.mean(dim=reduce_over, keepdim=True)

        # Remove each unit's own mean, then shift everything to the shared mean.
        centred = x - unit_mean
        return centred + unit_mean.mean()

In [None]:
class LocalReceptiveFieldLayer(nn.Module):
    """
    A convolutional layer (conv -> batch norm -> ReLU) that is trained on its
    own to separate positive from negative data: the mean squared activation
    ("goodness") is pushed above a threshold for positive samples and below it
    for negative samples.

    Attributes:
        conv (nn.Conv2d): Convolutional layer.
        batch_norm (nn.BatchNorm2d): Batch normalization applied after the convolution.
        relu (nn.ReLU): ReLU activation function.
        opt (Adam): Optimizer for the layer parameters.
        threshold (float): Threshold for distinguishing positive and negative samples.
        num_epochs (int): Number of training epochs.
        loss_values (list): Loss logged every 10th epoch.
        g_pos_values (list): Mean positive goodness logged every 10th epoch.
        g_neg_values (list): Mean negative goodness logged every 10th epoch.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, padding='same',
                 num_epochs=500, threshold=2.0, lr=0.03):
        """
        Initializes the LocalReceptiveFieldLayer with specified hyperparameters.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int or tuple): Size of the convolving kernel.
            stride (int or tuple): Stride of the convolution.
            padding (int, tuple or str): Padding passed to ``nn.Conv2d``.
            num_epochs (int): Number of training epochs.
            threshold (float): Threshold value to distinguish positive from negative samples.
            lr (float): Learning rate for the optimizer.
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs
        self.loss_values = []
        self.g_pos_values = []
        self.g_neg_values = []

    def forward(self, x):
        """
        Forward pass: convolution, batch normalization, then ReLU.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).

        Returns:
            torch.Tensor: Activations after convolution, batch normalization and ReLU.
        """
        return self.relu(self.batch_norm(self.conv(x)))

    def _goodness(self, x):
        """Per-sample goodness: mean squared activation over channels and space."""
        return self.forward(x).pow(2).mean(dim=[1, 2, 3])

    def _ff_loss(self, g_pos, g_neg):
        """
        Threshold loss ``mean(log(1 + exp(z)))`` with ``z = threshold - g_pos``
        for positive samples and ``z = g_neg - threshold`` for negative ones.

        ``F.softplus`` evaluates the same expression without overflowing to
        ``inf`` when the goodness is large.
        """
        margins = torch.cat([self.threshold - g_pos, g_neg - self.threshold])
        return F.softplus(margins).mean()

    def custom_train(self, x_pos, x_neg):
        """
        Train this layer for ``num_epochs`` full-batch steps so that goodness is
        above the threshold for positive samples and below it for negative ones.

        Loss and mean goodness values are appended to the logging lists every
        10th epoch.

        Args:
            x_pos (torch.Tensor): Positive samples tensor.
            x_neg (torch.Tensor): Negative samples tensor.

        Returns:
            tuple: Gradient-free layer outputs for the positive and the
                   negative samples after training.
        """
        for epoch in tqdm(range(self.num_epochs), desc="Training Progress"):
            # Calculate goodness metrics for positive and negative samples
            g_pos = self._goodness(x_pos)
            g_neg = self._goodness(x_neg)

            loss = self._ff_loss(g_pos, g_neg)

            # Perform backpropagation and optimization
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            # Log values every 10 epochs
            if epoch % 10 == 0:
                self.loss_values.append(loss.item())
                self.g_pos_values.append(g_pos.mean().item())
                self.g_neg_values.append(g_neg.mean().item())

        # The outputs only feed the next layer, so no autograd graph is needed.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)

In [None]:
class SimpleLayerWithPeerNormalization(nn.Module):
    """
    A fully connected layer (linear -> ReLU -> peer normalization) that is
    trained on its own to separate positive from negative data: the mean
    squared activation ("goodness") is pushed above a threshold for positive
    samples and below it for negative samples.

    Attributes:
        linear (nn.Linear): Fully connected layer.
        relu (nn.ReLU): ReLU activation function.
        peer_norm (PeerNormalization): Peer normalization applied after the ReLU.
        opt (Adam): Optimizer for the layer parameters.
        threshold (float): Threshold for distinguishing positive and negative samples.
        num_epochs (int): Number of training epochs.
        loss_values (list): Loss logged every 10th epoch.
        g_pos_values (list): Mean positive goodness logged every 10th epoch.
        g_neg_values (list): Mean negative goodness logged every 10th epoch.
    """

    def __init__(self, in_features, out_features, num_epochs=500, threshold=1.5, lr=0.03):
        """
        Initializes the SimpleLayerWithPeerNormalization with specified hyperparameters.

        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            num_epochs (int): Number of training epochs.
            threshold (float): Threshold value to separate positive and negative samples.
            lr (float): Learning rate for the optimizer.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.peer_norm = PeerNormalization()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs
        self.loss_values = []
        self.g_pos_values = []
        self.g_neg_values = []

    def forward(self, x):
        """
        Forward pass: flatten, linear transformation, ReLU, then peer normalization.

        Args:
            x (torch.Tensor): Input tensor whose first dimension is the batch;
                all remaining dimensions are flattened.

        Returns:
            torch.Tensor: Output of shape (batch, out_features).
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. the
        # result of a permute or a slice.
        x_flattened = x.reshape(x.size(0), -1)

        # Ensure the shape matches the expected in_features of the linear layer
        assert x_flattened.size(1) == self.linear.in_features, (
            f"Flattened input size {x_flattened.size(1)} does not match expected in_features {self.linear.in_features}"
        )

        return self.peer_norm(self.relu(self.linear(x_flattened)))

    def _goodness(self, x):
        """Per-sample goodness: mean squared activation over the output features."""
        return self.forward(x).pow(2).mean(dim=1)

    def _ff_loss(self, g_pos, g_neg):
        """
        Threshold loss ``mean(log(1 + exp(z)))`` with ``z = threshold - g_pos``
        for positive samples and ``z = g_neg - threshold`` for negative ones.

        ``F.softplus`` evaluates the same expression without overflowing to
        ``inf`` when the goodness is large.
        """
        margins = torch.cat([self.threshold - g_pos, g_neg - self.threshold])
        return F.softplus(margins).mean()

    def custom_train(self, x_pos, x_neg):
        """
        Train this layer for ``num_epochs`` full-batch steps so that goodness is
        above the threshold for positive samples and below it for negative ones.

        Loss and mean goodness values are appended to the logging lists every
        10th epoch.

        Args:
            x_pos (torch.Tensor): Positive samples tensor.
            x_neg (torch.Tensor): Negative samples tensor.

        Returns:
            tuple: Gradient-free layer outputs for the positive and the
                   negative samples after training.
        """
        for epoch in tqdm(range(self.num_epochs), desc="Training Progress"):
            # Calculate goodness metrics for positive and negative samples
            g_pos = self._goodness(x_pos)
            g_neg = self._goodness(x_neg)

            loss = self._ff_loss(g_pos, g_neg)

            # Perform backpropagation and optimization
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            # Log values every 10 epochs
            if epoch % 10 == 0:
                self.loss_values.append(loss.item())
                self.g_pos_values.append(g_pos.mean().item())
                self.g_neg_values.append(g_neg.mean().item())

        # The outputs only feed the next layer, so no autograd graph is needed.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)

In [None]:
class DeepResidualNetwork(nn.Module):
    """
    A deep network for the Forward-Forward (FF) algorithm made of convolutional
    feature-extraction layers followed by fully connected classifier layers.

    ``forward`` adds skip connections between feature-extraction layers, while
    ``custom_train`` and ``predict`` pass activations through the layers one
    after another without skip connections.

    Attributes:
        feature_extractor (nn.ModuleList): Layers in the feature extraction part of the network.
        classifier (nn.ModuleList): Layers in the classification part of the network.
        threshold (float): Stored threshold value. The layers built here use
            their own default thresholds.
        num_classes (int): Number of output classes.
    """

    def __init__(self, input_channels, hidden_channels, output_features, num_blocks=6, threshold=1.5,
                 num_classes=10, input_size=(32, 32)):
        """
        Initializes the DeepResidualNetwork with specified parameters.

        Args:
            input_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels in each layer.
            output_features (int): Number of output features (classes).
            num_blocks (int): Number of hidden-to-hidden layers added after the first layer.
            threshold (float): Threshold stored on the network.
            num_classes (int): Number of classes for classification.
            input_size (tuple): (height, width) of the input images, used to size
                the first classifier layer. Defaults to the CIFAR-10 size.
        """
        super().__init__()
        self.threshold = threshold
        self.num_classes = num_classes

        # Feature extractor: one input layer plus num_blocks hidden layers
        self.feature_extractor = nn.ModuleList()
        self.feature_extractor.append(LocalReceptiveFieldLayer(input_channels, hidden_channels))
        for _ in range(num_blocks):
            self.feature_extractor.append(LocalReceptiveFieldLayer(hidden_channels, hidden_channels))

        # Probe the feature extractor once to learn the flattened feature size.
        # Eval mode keeps this dummy pass from updating the BatchNorm running
        # statistics.
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, input_channels, *input_size)
            for layer in self.feature_extractor:
                probe = layer(probe)
        self.train()
        flattened_size = probe[0].numel()

        # Classifier layers
        self.classifier = nn.ModuleList()
        self.classifier.append(SimpleLayerWithPeerNormalization(flattened_size, hidden_channels))
        self.classifier.append(SimpleLayerWithPeerNormalization(hidden_channels, 32))
        self.classifier.append(nn.Linear(32, output_features))

    def forward(self, x):
        """
        Forward pass through the network with skip connections in the feature
        extractor.

        A layer's input is added to its output whenever both have the same
        shape. The first layer changes the channel count, so its input cannot
        be added and is skipped.

        Args:
            x (torch.Tensor): Input tensor (batch, channels, height, width).

        Returns:
            torch.Tensor: Output tensor after processing through feature extractor and classifier.
        """
        for layer in self.feature_extractor:
            out = layer(x)
            x = out + x if out.shape == x.shape else out

        x = x.view(x.size(0), -1)  # Flatten for fully connected layers
        for layer in self.classifier:
            x = layer(x)
        return x

    @torch.no_grad()
    def predict(self, x):
        """
        Predict labels by trying every candidate label and keeping the one with
        the highest goodness.

        For each label, that label is overlaid on the input, the result is
        passed through all layers, and the sum of squared outputs is used as
        the goodness score.

        Args:
            x (torch.Tensor): Input tensor (batch, channels, height, width).

        Returns:
            torch.Tensor: Predicted labels for each sample in the batch.
        """
        overlay = NegativeDataGenerator(num_classes=self.num_classes)
        goodness_per_label = []

        for label in range(self.num_classes):
            # Without this overlay every candidate would see identical input
            # and argmax would always return label 0.
            candidate = torch.full((x.size(0),), label, dtype=torch.long, device=x.device)
            h = overlay._overlay_y_on_x(x, candidate)

            for layer in self.feature_extractor:
                h = layer(h)

            h = h.view(h.size(0), -1)

            for layer in self.classifier:
                h = layer(h)

            goodness_per_label.append(h.pow(2).sum(dim=1))

        # (batch, num_classes) -> index of the best-scoring label per sample
        return torch.stack(goodness_per_label, dim=1).argmax(dim=1)

    def custom_train(self, x_pos, x_neg):
        """
        Train the network layer by layer following the Forward-Forward Algorithm.

        Layers that define ``custom_train`` are trained on the current positive
        and negative activations; any other layer is simply applied to both.

        Args:
            x_pos (torch.Tensor): Positive samples.
            x_neg (torch.Tensor): Negative samples.

        Returns:
            tuple: Final forward pass outputs for positive and negative samples.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(list(self.feature_extractor) + list(self.classifier)):
            if hasattr(layer, 'custom_train'):
                print(f"Training layer {i}: {layer.__class__.__name__}")
                h_pos, h_neg = layer.custom_train(h_pos, h_neg)
            else:
                h_pos = layer(h_pos)
                h_neg = layer(h_neg)
        return h_pos, h_neg

In [None]:
class SequentialDeepNetwork(nn.Module):
    """
    A deep sequential network designed for the Forward-Forward (FF) algorithm.

    The feature extractor alternates convolutional ``LocalReceptiveFieldLayer``
    blocks with adaptive average pooling; the classifier stacks fully connected
    layers with peer normalization, dropout and a final linear layer.
    """

    def __init__(self, input_channels, hidden_channels, output_features, num_blocks=6, pool_kernel=2, num_classes=10):
        """
        Initializes the SequentialDeepNetwork with specified parameters.

        Args:
            input_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels in each layer.
            output_features (int): Number of output features (classes).
            num_blocks (int): Controls the depth of the feature extractor: one
                input block, ``num_blocks - 1`` middle blocks and one final block.
            pool_kernel (int): Currently unused; kept so existing calls keep working.
            num_classes (int): Number of classes for classification.
        """
        super().__init__()
        self.num_classes = num_classes

        # Feature extraction: each conv block is followed by adaptive average
        # pooling to a fixed spatial size (16x16, then 8x8, finally 4x4).
        layers = [
            LocalReceptiveFieldLayer(input_channels, hidden_channels),
            nn.AdaptiveAvgPool2d((16, 16)),
        ]
        for _ in range(num_blocks - 1):
            layers.append(LocalReceptiveFieldLayer(hidden_channels, hidden_channels))
            layers.append(nn.AdaptiveAvgPool2d((8, 8)))

        layers.append(LocalReceptiveFieldLayer(hidden_channels, hidden_channels))
        layers.append(nn.AdaptiveAvgPool2d((4, 4)))

        self.feature_extractor = nn.Sequential(*layers)

        # The final pooling fixes the feature map at 4x4, so the flattened size
        # is hidden_channels * 4 * 4 regardless of the input resolution.
        self.classifier = nn.Sequential(
            SimpleLayerWithPeerNormalization(hidden_channels * 4 * 4, hidden_channels),
            nn.Dropout(0.2),
            SimpleLayerWithPeerNormalization(hidden_channels, 32),
            nn.Dropout(0.2),
            nn.Linear(32, output_features)
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor (batch, channels, height, width).

        Returns:
            torch.Tensor: Output tensor after processing through feature extractor and classifier.
        """
        features = self.feature_extractor(x)
        return self.classifier(features.view(features.size(0), -1))

    @torch.no_grad()
    def predict(self, x):
        """
        Predict labels by trying every candidate label and keeping the one with
        the highest goodness.

        For each label, that label is overlaid on the input, the result is
        passed through the network, and the sum of squared outputs is used as
        the goodness score. Runs without gradient tracking.

        Args:
            x (torch.Tensor): Input tensor (batch, channels, height, width).

        Returns:
            torch.Tensor: Predicted labels for each sample in the batch.
        """
        # One overlay helper serves every candidate label.
        overlay = NegativeDataGenerator(num_classes=self.num_classes)
        goodness_per_label = []

        for label in range(self.num_classes):
            candidate = torch.full((x.size(0),), label, dtype=torch.long, device=x.device)
            h = overlay._overlay_y_on_x(x, candidate)

            h = self.feature_extractor(h)
            h = self.classifier(h.view(h.size(0), -1))

            goodness_per_label.append(h.pow(2).sum(dim=1))

        # (batch, num_classes) -> index of the best-scoring label per sample
        return torch.stack(goodness_per_label, dim=1).argmax(dim=1)

    def custom_train(self, x_pos, x_neg):
        """
        Train the network layer by layer following the Forward-Forward Algorithm.

        Layers that define ``custom_train`` are trained on the current positive
        and negative activations; any other layer (pooling, dropout, the final
        linear layer) is simply applied to both.

        Args:
            x_pos (torch.Tensor): Positive samples.
            x_neg (torch.Tensor): Negative samples.

        Returns:
            tuple: Final forward pass outputs for positive and negative samples.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(list(self.feature_extractor) + list(self.classifier)):
            if hasattr(layer, 'custom_train'):
                print(f"Training layer {i}: {layer.__class__.__name__}")
                h_pos, h_neg = layer.custom_train(h_pos, h_neg)
            else:
                h_pos = layer(h_pos)
                h_neg = layer(h_neg)
        return h_pos, h_neg


In [None]:
# Model hyperparameters: 3 input channels for CIFAR-10-like RGB data
input_channels, hidden_channels, output_features, num_blocks = 3, 64, 10, 8

# Build the sequential FF network and move its parameters to the chosen device
Net = SequentialDeepNetwork(
    input_channels,
    hidden_channels,
    output_features,
    num_blocks=num_blocks,
)
Net.to(device)

In [None]:
# torchinfo is imported in this cell rather than with the imports at the top of the notebook
from torchinfo import summary
# Summarize Net for an input of shape (32, 3, 32, 32)
summary(Net, input_size=(32, 3, 32, 32))

In [None]:
# Scale the positive (x) and negative (x_neg) batches by 1/255 before training.
# Both names are reassigned from themselves, so re-running this cell divides by 255 again.
x = x / 255.0
x_neg = x_neg / 255.0

In [None]:
# Run layer-by-layer Forward-Forward training with x as positive and x_neg as negative samples
Net.custom_train(x, x_neg)

In [None]:
# Save the trained model's state_dict (not the module object itself) to a relative path
torch.save(Net.state_dict(), "ff_sequential_model.pth")
print("Model saved successfully!")

In [None]:
# Load the saved state_dict. map_location remaps the stored tensors onto the
# current `device`, so a checkpoint written on GPU still loads on a CPU-only session.
Net.load_state_dict(
    torch.load("ff_sequential_model.pth", map_location=device, weights_only=True)
)

# Set the model to evaluation mode
Net.eval()

In [None]:
def calculate_training_error(Net, train_loader, device):
    """
    Calculate the training error for the given network using the predict method.

    The network is switched to evaluation mode and left in that mode on return.

    Args:
        Net (nn.Module): The trained network; must expose ``predict(x)``.
        train_loader (DataLoader): Iterable yielding ``(x_batch, y_batch)`` pairs.
        device (torch.device): Device to perform computations on (CPU or GPU).

    Returns:
        float: The training error as a percentage.

    Raises:
        ValueError: If ``train_loader`` yields no samples, since the error rate
            is undefined in that case.
    """
    Net.eval()  # Set the network to evaluation mode
    total_samples = 0
    correct_predictions = 0

    with torch.no_grad():  # No gradient computation during evaluation
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)  # Move data to device
            predictions = Net.predict(x_batch)  # Use the predict method
            correct_predictions += (predictions == y_batch).sum().item()
            total_samples += y_batch.size(0)

    # Fail with a clear message instead of an opaque ZeroDivisionError
    if total_samples == 0:
        raise ValueError("train_loader yielded no samples; cannot compute training error")

    training_error = 1 - (correct_predictions / total_samples)
    return training_error * 100  # Return as a percentage

In [None]:
# Fraction of the first six samples of x whose prediction differs from y
1.0 - Net.predict(x[0:6]).eq(y[0:6]).float().mean().item()

In [None]:
# Error over every batch yielded by train_loader, reported as a percentage
training_error = calculate_training_error(Net, train_loader, device)
print(f"Training Error: {training_error:.2f}%")

In [None]:
# Class names for CIFAR-10
class_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Store correct and incorrect predictions
# correct_preds holds (image, label); incorrect_preds holds (image, predicted, label)
correct_preds = []
incorrect_preds = []

# Run inference and collect predictions
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predictions = Net.predict(images)

        # Compare and copy to the CPU once per batch rather than once per sample
        is_correct = (predictions == labels).cpu().tolist()
        images_cpu = images.cpu()
        predictions_cpu = predictions.cpu()
        labels_cpu = labels.cpu()

        # Sort every sample into one of the two lists, preserving batch order
        for i, hit in enumerate(is_correct):
            if hit:
                correct_preds.append((images_cpu[i], labels_cpu[i]))
            else:
                incorrect_preds.append((images_cpu[i], predictions_cpu[i], labels_cpu[i]))

# Visualize Correct Predictions
def visualize_correct_predictions(correct_preds, num_samples=5):
    """
    Plot the first few correctly classified images in a single row.

    Args:
        correct_preds (list): ``(image, label)`` pairs; ``image`` is permuted
            with ``(1, 2, 0)`` before display and ``label`` indexes ``class_names``.
        num_samples (int): Maximum number of images to show. Fewer are shown
            when ``correct_preds`` holds fewer entries.

    NOTE(review): images are displayed as stored; if they were normalized
    upstream, imshow may clip values outside its displayable range — confirm
    whether un-normalizing is wanted.
    """
    # Never index past the end of the list
    shown = min(num_samples, len(correct_preds))
    if shown <= 0:
        print("No correct predictions to display.")
        return

    plt.figure(figsize=(12, 6))
    for i in range(shown):
        image, label = correct_preds[i]
        plt.subplot(1, shown, i + 1)
        plt.imshow(image.permute(1, 2, 0))  # Convert to HWC format
        plt.title(f"Correct: {class_names[int(label)]}")
        plt.axis("off")
    plt.show()

# Visualize Incorrect Predictions
def visualize_incorrect_predictions(incorrect_preds, num_samples=5):
    """
    Plot the first few misclassified images in a single row.

    Args:
        incorrect_preds (list): ``(image, predicted, true)`` triples; ``image``
            is permuted with ``(1, 2, 0)`` before display, and ``predicted`` /
            ``true`` index ``class_names``.
        num_samples (int): Maximum number of images to show. Fewer are shown
            when ``incorrect_preds`` holds fewer entries.

    NOTE(review): images are displayed as stored; if they were normalized
    upstream, imshow may clip values outside its displayable range — confirm
    whether un-normalizing is wanted.
    """
    # Never index past the end of the list
    shown = min(num_samples, len(incorrect_preds))
    if shown <= 0:
        print("No incorrect predictions to display.")
        return

    plt.figure(figsize=(12, 6))
    for i in range(shown):
        image, predicted, true = incorrect_preds[i]
        plt.subplot(1, shown, i + 1)
        plt.imshow(image.permute(1, 2, 0))  # Convert to HWC format
        plt.title(f"Pred: {class_names[int(predicted)]}\nTrue: {class_names[int(true)]}")
        plt.axis("off")
    plt.show()

# Display results
print("Correct Predictions:")
visualize_correct_predictions(correct_preds, num_samples=7)

print("Incorrect Predictions:")
visualize_incorrect_predictions(incorrect_preds, num_samples=7)

# Some Manual Debugging of Network Parameters

In [None]:
# Function to visualize and trace tensor shapes
def test_feature_extractor_with_real_data(Net, sample_image):
    """
    Test the feature extractor using a real sample image and trace the tensor shapes.

    Args:
        Net (nn.Module): The network object containing the feature extractor and classifier.
        sample_image (torch.Tensor): A single sample image from the dataset.
    """
    print("Input image shape:", sample_image.shape)

    # Visualize the input image
    plt.imshow(sample_image.permute(1, 2, 0).cpu().numpy())  # Adjust permutation for visualization
    plt.title("Input Image")
    plt.show()

    # Send the image through the feature extractor
    x = sample_image.unsqueeze(0)  # Add batch dimension
    for i, layer in enumerate(Net.feature_extractor):
        x = layer(x)
        print(f"Shape after feature extractor layer {i}: {x.shape}")

    # Flatten the output and check shape
    x_flattened = x.view(x.size(0), -1)
    print("Shape after flattening:", x_flattened.shape)

    # Pass through the classifier
    for i, layer in enumerate(Net.classifier):
        x_flattened = layer(x_flattened)
        print(f"Shape after classifier layer {i}: {x_flattened.shape}")

In [None]:
# Select a sample image and its label: the first entries of x and y
sample_image, label = x[0], y[0]

In [None]:
# Inspect the shape of the selected image
sample_image.shape

In [None]:
# Inspect the label paired with the selected image
label

In [None]:
# Plot the selected image and print the tensor shape after each layer of Net
test_feature_extractor_with_real_data(Net, sample_image)

In [None]:
# Convert `labels` to a plain Python list.
# NOTE(review): `labels` appears to be the loop variable left over from the
# test_loader loop in an earlier cell, i.e. only the final batch — confirm this is intended.
lab = labels.tolist()

In [None]:
# Counter is imported in this cell rather than with the imports at the top of the notebook
from collections import Counter
# Count how many times each label value occurs in `lab`
Counter(lab)