In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST, CIFAR10, FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
from IPython.display import clear_output
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:

# Data Loader
class DataLoaderManager:
    """
    DataLoader manager class to handle data loading for various datasets.

    Supported dataset names are listed in ``SUPPORTED_DATASETS``.
    """

    # Names accepted by load_dataset().
    SUPPORTED_DATASETS = ('MNIST', 'CIFAR10', 'FashionMNIST')

    def __init__(self, dataset_name, train_batch_size=50000, test_batch_size=10000):
        """
        Initializes the DataLoaderManager with specified dataset and batch sizes.

        Args:
        dataset_name (str): Name of the dataset ('MNIST', 'CIFAR10', 'FashionMNIST').
        train_batch_size (int): Batch size for the training dataset.
        test_batch_size (int): Batch size for the test dataset.
        """
        self.dataset_name = dataset_name
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.transform = self.get_transforms()

    def get_transforms(self):
        """
        Returns the appropriate transformation for each dataset.

        CIFAR10 images are converted to tensors and normalized per channel.
        Every other dataset name gets the single-channel normalization and is
        additionally flattened to a 1D vector.

        Returns:
        torchvision.transforms.Compose: Transformation pipeline.
        """
        if self.dataset_name == 'CIFAR10':
            return Compose([
                ToTensor(),
                Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
            ])
        # Default to MNIST and FashionMNIST normalization
        return Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,)),
            Lambda(lambda x: torch.flatten(x))
        ])

    def load_dataset(self, train=True):
        """
        Loads the specified dataset (downloading it to ./data/ if needed).

        Args:
        train (bool): If True, load training dataset; otherwise, load test dataset.

        Returns:
        Dataset: The requested dataset.

        Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
        """
        if self.dataset_name == 'MNIST':
            return MNIST('./data/', train=train, download=True, transform=self.transform)
        if self.dataset_name == 'CIFAR10':
            return CIFAR10('./data/', train=train, download=True, transform=self.transform)
        if self.dataset_name == 'FashionMNIST':
            return FashionMNIST('./data/', train=train, download=True, transform=self.transform)
        # Fail loudly instead of returning None, which would only surface later
        # as an obscure error inside DataLoader.
        raise ValueError(
            f"Unsupported dataset_name {self.dataset_name!r}; "
            f"expected one of {self.SUPPORTED_DATASETS}"
        )

    def get_data_loaders(self):
        """
        Creates and returns data loaders for training and testing datasets.

        The training loader shuffles its data; the test loader does not.

        Returns:
        tuple: Tuple containing the training and testing data loaders.
        """
        train_loader = DataLoader(self.load_dataset(train=True),
                                  batch_size=self.train_batch_size, shuffle=True)
        test_loader = DataLoader(self.load_dataset(train=False),
                                 batch_size=self.test_batch_size, shuffle=False)
        return train_loader, test_loader

In [None]:
def overlay_y_on_x(x, y, classes=10):
    """
    Embed the label y into a batch of images x as a one-hot pattern.

    In every channel of every image, the first 'classes' pixels of the top row
    are set to zero, and the pixel whose column index equals the label is then
    set to the maximum value found anywhere in x. The input tensor is not modified.

    Args:
    x (torch.Tensor): The input tensor, expected to be 4D (batch, channels, height, width).
    y (torch.Tensor or int): Either one label per item in the batch, or a single
        integer label applied to every item.
    classes (int): The number of classes, i.e. the width of the area to be zeroed and used for encoding.

    Returns:
    torch.Tensor: A modified copy of x.
    """
    x_ = x.clone()  # Work on a copy so the caller's tensor stays untouched.
    max_value = x.max()  # Marker value: the maximum over the entire batch.
    # Assign zeros directly: this avoids the gather/scatter of a fancy-indexed
    # in-place multiply and also clears non-finite pixels (inf * 0 would be NaN).
    x_[:, :, 0, :classes] = 0.0
    # Pair each batch index with its own label column; all channels get the marker.
    x_[range(x.shape[0]), :, 0, y] = max_value
    return x_

In [None]:
def overlay_y_on_x_1D(x, y, classes=10):
    """Replace the first `classes` pixels of flattened data [x] with one-hot-encoded label [y].

    The leading `classes` entries of every row are set to zero and the entry at
    the label index is set to the maximum value found anywhere in x. The input
    tensor is not modified.

    Args:
    x (torch.Tensor): 2D input tensor indexed as (batch, features).
    y (torch.Tensor or int): One label per row, or a single integer label applied to every row.
    classes (int): Number of leading entries reserved for the label encoding (default 10).

    Returns:
    torch.Tensor: A modified copy of x.
    """
    x_ = x.clone()
    # Direct assignment also clears non-finite entries (inf * 0 would be NaN).
    x_[:, :classes] = 0.0
    # Pair each row index with its own label column.
    x_[range(x.shape[0]), y] = x.max()
    return x_

In [None]:
def get_y_neg(y, num_classes=10):
    """
    Generates negative labels for a batch of labels by ensuring each negative label
    is different from the original label.

    Each negative label is drawn uniformly from the num_classes - 1 classes other
    than the true one. This is done for the whole batch at once by adding a random
    offset in [1, num_classes - 1] to the true label and wrapping modulo num_classes,
    which avoids a Python loop (and a host/device sync) per sample.

    Args:
    y (torch.Tensor): An integer tensor containing a batch of labels in [0, num_classes).
    num_classes (int): The total number of classes.

    Returns:
    torch.Tensor: A new tensor of negative labels with the same shape, dtype and device as y.

    Raises:
    ValueError: If num_classes is smaller than 2 (no different label exists).
    """
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2 to draw a different label")
    # A non-zero offset guarantees (y + offset) % num_classes != y.
    offsets = torch.randint(1, num_classes, y.shape, dtype=y.dtype, device=y.device)
    return (y + offsets) % num_classes

In [None]:
class ConvLayer(nn.Module):
    """
    A convolutional layer class that implements the Forward-Forward Algorithm for distinguishing
    between positive and negative samples by adjusting the layer weights based on a specified threshold.

    The layer owns its optimizer and is trained locally through custom_train();
    no gradient flows to other layers.

    Attributes:
        conv (nn.Conv2d): Convolutional layer.
        relu (nn.ReLU): Rectified Linear Unit activation function.
        opt (Adam): Optimizer for updating layer weights.
        threshold (float): Threshold for differentiating positive from negative samples.
        num_epochs (int): Number of epochs for the custom training loop.
        log_interval (int): Interval at which to log training progress.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True, lr=0.01, threshold=0.1, num_epochs=100, log_interval=10):
        """
        Initializes the ConvLayer with convolutional operation and training parameters.

        Args:
            in_channels (int): Number of channels in the input.
            out_channels (int): Number of channels produced by the convolution.
            kernel_size (int or tuple): Size of the convolving kernel.
            stride (int or tuple): Stride of the convolution.
            padding (int or tuple): Zero-padding added to both sides of the input.
            bias (bool): If True, adds a learnable bias to the output.
            lr (float): Learning rate for the optimizer.
            threshold (float): Goodness threshold for differentiating samples.
            num_epochs (int): Total number of training epochs.
            log_interval (int): Interval for logging training progress.
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.relu = nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs
        self.log_interval = log_interval

    def forward(self, x):
        """
        Forward pass of the ConvLayer.

        The input is divided by its L2 norm along dim 1 (plus a small epsilon to
        avoid division by zero) before the convolution and ReLU are applied.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Activated output after the normalization and convolution.
        """
        # L2 normalization on the input tensor
        x_direction = x / (x.norm(p=2, dim=1, keepdim=True) + 1e-6)
        x = self.conv(x_direction)
        return self.relu(x)

    def custom_train(self, x_pos, x_neg):
        """
        Custom training loop to adjust weights based on the positive and negative samples.

        The goodness of a sample is the mean squared activation of this layer. The
        loss pushes the goodness of positive samples above the threshold and the
        goodness of negative samples below it.

        Args:
            x_pos (torch.Tensor): Positive samples.
            x_neg (torch.Tensor): Negative samples.

        Returns:
            tuple: Detached tensors of forward passes from positive and negative samples after training.
        """
        loss_values, g_pos_values, g_neg_values = [], [], []

        for i in range(self.num_epochs):
            g_pos = self.forward(x_pos).pow(2).mean(dim=(1, 2, 3))
            g_neg = self.forward(x_neg).pow(2).mean(dim=(1, 2, 3))
            # softplus(z) == log(1 + exp(z)) but does not overflow for large z,
            # which would otherwise produce an infinite loss and NaN gradients.
            loss = nn.functional.softplus(
                torch.cat([self.threshold - g_pos, g_neg - self.threshold])
            ).mean()
            self.opt.zero_grad()
            # The graph is rebuilt on every iteration, so it does not need to be retained.
            loss.backward()
            self.opt.step()

            if i % self.log_interval == 0:
                loss_values.append(loss.item())
                g_pos_values.append(g_pos.mean().item())
                g_neg_values.append(g_neg.mean().item())
                self._plot_training_progress(loss_values, g_pos_values, g_neg_values, i)

        # The outputs feed the next layer as plain data; skip building autograd graphs.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)

    def _plot_training_progress(self, loss_values, g_pos_values, g_neg_values, step):
        """
        Plots the training progress, replacing the cell's previous output.

        Args:
            loss_values (list): Recorded loss values.
            g_pos_values (list): Recorded values of positive samples' goodness.
            g_neg_values (list): Recorded values of negative samples' goodness.
            step (int): Current training step.
        """
        clear_output(wait=True)
        plt.figure(figsize=(12, 8))
        plt.subplot(3, 1, 1)
        plt.plot(loss_values, color='blue')
        plt.title(f"Loss during training at step {step}")

        plt.subplot(3, 1, 2)
        plt.plot(g_pos_values, color='green')
        plt.title("g_pos during training")

        plt.subplot(3, 1, 3)
        plt.plot(g_neg_values, color='red')
        plt.title("g_neg during training")

        plt.tight_layout()
        plt.show()

In [None]:
class Layer(nn.Linear):
    """
    A custom layer class that extends nn.Linear for the purpose of training with
    a specific loss function that discriminates between positive and negative samples.

    The layer owns its optimizer and is trained locally through custom_train().

    Attributes:
        relu (nn.ReLU): Activation function.
        opt (Adam): Optimizer for the layer parameters.
        threshold (float): Threshold value to separate positive and negative samples.
        num_epochs (int): Number of training epochs.
        loss_values (list): List to store loss values for visualization.
        g_pos_values (list): List to store values of goodness metric for positive samples.
        g_neg_values (list): List to store values of goodness metric for negative samples.
    """
    def __init__(self, in_features, out_features,kernel_size=3,stride=None,padding=None, bias=True, device=None, dtype=None, num_epochs=1000, threshold=2.0, lr=0.03):
        """
        Initializes the linear layer and its training parameters.

        Args:
            in_features (int): Size of each input sample.
            out_features (int): Size of each output sample.
            kernel_size, stride, padding: Accepted but not used by this layer.
            bias (bool): If True, adds a learnable bias to the output.
            device: Device passed through to nn.Linear.
            dtype: Data type passed through to nn.Linear.
            num_epochs (int): Number of iterations of the custom training loop.
            threshold (float): Goodness threshold for differentiating samples.
            lr (float): Learning rate for the optimizer.
        """
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.relu = nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs
        self.loss_values = []
        self.g_pos_values = []
        self.g_neg_values = []

    def forward(self, x):
        """
        Forward pass of the layer which normalizes the input before applying the linear transformation and activation.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying linear transformation and ReLU activation.
        """
        x_normalized = x / (x.norm(p=2, dim=1, keepdim=True) + 1e-4)
        # functional.linear handles bias=None, so the layer also works when built with bias=False.
        return self.relu(nn.functional.linear(x_normalized, self.weight, self.bias))

    def custom_train(self, x_pos, x_neg):
        """
        Custom training loop that adjusts weights based on the comparison of positive and negative samples
        against a set threshold.

        The goodness of a sample is the mean squared activation of this layer. Loss
        and mean goodness values are recorded every 10th iteration in loss_values,
        g_pos_values and g_neg_values.

        Args:
            x_pos (torch.Tensor): Positive samples.
            x_neg (torch.Tensor): Negative samples.

        Returns:
            tuple: Detached tensors of the outputs from the final forward pass of positive and negative samples.
        """
        for i in tqdm(range(self.num_epochs)):
            g_pos = self.forward(x_pos).pow(2).mean(dim=1)
            g_neg = self.forward(x_neg).pow(2).mean(dim=1)
            # softplus(z) == log(1 + exp(z)) but does not overflow for large z,
            # which would otherwise produce an infinite loss and NaN gradients.
            loss = nn.functional.softplus(
                torch.cat([self.threshold - g_pos, g_neg - self.threshold])
            ).mean()
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            if i % 10 == 0:  # Log at every 10th epoch
                self.loss_values.append(loss.item())
                self.g_pos_values.append(g_pos.mean().item())
                self.g_neg_values.append(g_neg.mean().item())

        # The outputs feed the next layer as plain data; skip building autograd graphs.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)

In [None]:
class Net(nn.Module):
    """
    A neural network that incorporates both convolutional and fully connected layers,
    specifically designed to distinguish between positive and negative samples using
    the Forward-Forward Algorithm. This network adjusts to different image sizes.

    Attributes:
        num_classes (int): Number of candidate labels tried by predict().
        feature_extractor (nn.Sequential): Three ConvLayer stages.
        classifier (nn.Sequential): Flatten followed by two fully connected Layer stages.
    """
    def __init__(self, input_channels, num_classes=10, input_size=28):
        """
        Builds the network.

        Args:
            input_channels (int): Number of channels of the input images.
            num_classes (int): Number of classes (size of the last layer).
            input_size (int): Height and width of the (square) input images.
        """
        super(Net, self).__init__()
        self.num_classes = num_classes
        self.feature_extractor = nn.Sequential(
            ConvLayer(input_channels, 16, kernel_size=5, padding=2),
            ConvLayer(16, 32, kernel_size=5, stride=2, padding=2),
            ConvLayer(32, 64, kernel_size=5, stride=2, padding=2),
        )

        # The number of output features from the last conv layer is determined dynamically
        # by pushing a single dummy image through the conv layers.
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, input_size, input_size)
            output_features = self.feature_extractor(dummy_input).numel()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            Layer(output_features, 128),
            Layer(128, num_classes)
        )

    def forward(self, x):
        """
        Forward pass through the network, applying each layer to the input data.
        """
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x

    def predict(self, x):
        """
        Predicts one label per sample by trying every candidate label.

        For each candidate label the label is overlaid on the images, the images
        are passed through the network, and the per-sample goodness (mean squared
        activation) of every trainable layer is summed. The label with the highest
        total goodness wins.

        Args:
            x (torch.Tensor): Batch of input images (without label overlay).

        Returns:
            torch.Tensor: Predicted labels, one per sample in the batch.
        """
        layers = list(self.feature_extractor) + list(self.classifier)
        goodness_per_label = []
        with torch.no_grad():  # Inference only; no autograd graph is needed.
            for label in range(self.num_classes):
                h = overlay_y_on_x(x, label, classes=self.num_classes)
                goodness = []
                for layer in layers:
                    h = layer(h)
                    # Only layers trained with the Forward-Forward objective contribute;
                    # Flatten merely reshapes and would double-count the previous layer.
                    if hasattr(layer, 'custom_train'):
                        # Reduce over everything except the batch dimension so that
                        # each sample keeps its own goodness value.
                        goodness.append(h.pow(2).flatten(start_dim=1).mean(dim=1))
                goodness_per_label.append(torch.stack(goodness).sum(dim=0))
        # Shape (batch, num_classes): pick the best label for every sample.
        return torch.stack(goodness_per_label, dim=1).argmax(dim=1)

    def custom_train(self, x_pos, x_neg):
        """
        Custom training method that adjusts network weights layer-by-layer based on the Forward-Forward Algorithm.

        Each trainable layer is trained on the outputs of the previous layer; layers
        without a custom_train method are simply applied to the data.

        Args:
            x_pos (torch.Tensor): Positive samples.
            x_neg (torch.Tensor): Negative samples.

        Returns:
            tuple: Outputs of the last layer for the positive and negative samples.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(list(self.feature_extractor) + list(self.classifier)):
            if hasattr(layer, 'custom_train'):
                print(f"Training layer: {i}")
                h_pos, h_neg = layer.custom_train(h_pos, h_neg)
            else:
                h_pos = layer(h_pos)
                h_neg = layer(h_neg)
        return h_pos, h_neg

In [None]:
# class Net(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.layers = nn.Sequential(
#             Layer(3, 64, kernel_size=5, padding=2),  # Layer 0
#             nn.ReLU(inplace=True),

#             Layer(64, 128, kernel_size=5, padding=2, stride=2),  # Layer 2
#             nn.ReLU(inplace=True),

#             Layer(128, 256, kernel_size=5, padding=2, stride=2),  # Layer 4
#             nn.ReLU(inplace=True),

#             Layer(256, 512, kernel_size=5, padding=2, stride=2),  # Layer 6
#             nn.ReLU(inplace=True),

#             Layer(512, 1024, kernel_size=5, padding=2, stride=2),  # Layer 8
#             nn.ReLU(inplace=True),

#             Layer(1024, 2048, kernel_size=5, padding=2, stride=2),  # Layer 10
#             nn.ReLU(inplace=True),

#             Layer(2048, 10, kernel_size=1, stride=1),  # Layer 12
#             nn.ReLU(inplace=True),
#         )
#     def forward(self, x):
#         return self.layers(x)

#     def predict(self, x):
#         goodness_per_label = []
#         for label in range(10):
#             h = overlay_y_on_x(x, label)
#             goodness = []
#             for layer in self.layers:
#                 h = layer(h)
#                 goodness += [(h.pow(2).sum() / h.numel()).unsqueeze(0)]
#             goodness_per_label += [torch.sum(torch.stack(goodness)).unsqueeze(0)]
#         goodness_per_label = torch.cat(goodness_per_label, 0)
#         return goodness_per_label.argmax(0)

#     def custom_train(self, x_pos, x_neg):
#         h_pos, h_neg = x_pos, x_neg
#         for i, layer in enumerate(self.layers):
#             print("training layer: ", i)
#             if isinstance(layer, Layer):  # only call custom_train on instances of the Layer class
#                 h_pos, h_neg = layer.custom_train(h_pos, h_neg)
#             elif isinstance(layer, ConvLayer):  # only call custom_train on instances of the ConvLayer class
#                 h_pos, h_neg = layer.custom_train(h_pos, h_neg)
#             else:  # for other layers, just pass the data through
#                 h_pos = layer(h_pos)
#                 h_neg = layer(h_neg)

In [None]:
# class Net(nn.Module):
#     """
#     A neural network that incorporates both convolutional and fully connected layers specifically designed
#     to distinguish between positive and negative samples using the Forward-Forward Algorithm.
#     """
#     def __init__(self, input_channels, num_classes=10):
#         super(Net, self).__init__()
#         self.feature_extractor = nn.Sequential(
#             ConvLayer(input_channels, 16, kernel_size=5, padding=2),  # Assume input is image data
#             ConvLayer(16, 32, kernel_size=5, stride=2, padding=2),
#             ConvLayer(32, 64, kernel_size=5, stride=2, padding=2),
#         )

#         self.classifier = nn.Sequential(
#             nn.Flatten(),
#             Layer(64 * 7 * 7, 128),  # Adjust the flattened size according to input dimension
#             Layer(128, num_classes)
#         )

#         self.layers = nn.ModuleList([*self.feature_extractor, *self.classifier])  # Consolidate all layers


#     def forward(self, x):
#         """
#         Forward pass through the network, applying each layer to the input data.

#         Args:
#             x (torch.Tensor): The input tensor to the network.

#         Returns:
#             torch.Tensor: The output tensor from the network.
#         """
#         for layer in self.layers:
#             x = layer(x)
#         return x

#     def predict(self, x):
#         """
#         Predict function to compute the goodness per label for each sample in the batch.

#         Args:
#             x (torch.Tensor): The input tensor.

#         Returns:
#             torch.Tensor: The tensor containing the predicted class labels.
#         """
#         goodness_per_label = []
#         for label in range(10):  # Assuming 10 possible classes
#             h = overlay_y_on_x(x, label)
#             goodness = []
#             for layer in self.layers:
#                 h = layer(h)
#                 goodness.append((h.pow(2).sum() / h.numel()).unsqueeze(0))
#             goodness_per_label.append(torch.sum(torch.stack(goodness)).unsqueeze(0))
#         goodness_per_label = torch.cat(goodness_per_label, 0)
#         return goodness_per_label.argmax(0)

#     def custom_train(self, x_pos, x_neg):
#         """
#         Custom training method that adjusts network weights layer-by-layer based on the Forward-Forward Algorithm.

#         Args:
#             x_pos (torch.Tensor): Positive samples.
#             x_neg (torch.Tensor): Negative samples.
#         """
#         h_pos, h_neg = x_pos, x_neg
#         for i, layer in enumerate(self.layers):
#             if hasattr(layer, 'custom_train'):
#                 print(f"Training layer: {i}")
#                 h_pos, h_neg = layer.custom_train(h_pos, h_neg)
#             else:
#                 h_pos = layer(h_pos)
#                 h_neg = layer(h_neg)
#         # h_pos, h_neg = x_pos, x_neg
#         # # layers = list(self.feature_extractor) + list(self.classifier)
#         # for i, layer in enumerate(self.layers):
#         #     if hasattr(layer, 'train'):
#         #         print(f'Training layer {i}...')
#         #         if isinstance(layer, Layer):  # only call custom_train on instances of the Layer class
#         #         h_pos, h_neg = layer.custom_train(h_pos, h_neg)
#         #         # h_pos, h_neg = layer.train(h_pos, h_neg)
#         #         # layer.train() # Put the layer in training mode
#         #         # h_pos = layer(h_pos)  # Apply the layer to positive samples
#         #         # h_neg = layer(h_neg)  # Apply the layer to negative samples
#         #     else:
#         #         h_pos = layer(h_pos)
#         #         h_neg = layer(h_neg)

In [None]:
def visualize_sample(data, name='', idx=0):
    """
    Shows one sample of `data` as a 28x28 grayscale image.

    The selected sample is moved to the CPU and reshaped to 28x28, so it must
    contain exactly 784 values.

    Args:
        data (torch.Tensor): The dataset containing the samples.
        name (str): The title label for the image.
        idx (int): The index of the sample in the dataset to visualize.
    """
    # Pick the sample first, then bring it into image layout on the CPU.
    image = data[idx].cpu().reshape(28, 28)
    _, ax = plt.subplots(figsize=(4, 4))
    ax.set_title(name)
    ax.imshow(image, cmap="gray")
    ax.axis('off')  # Hide ticks and frame around the image.
    plt.show()


In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# Initialize the network for CIFAR-10, the dataset this notebook loads:
# 3-channel 32x32 input images and 10 classes. (Net requires input_channels.)
net = Net(input_channels=3, num_classes=10, input_size=32)
net.to(device)  # Move the network to the configured device

In [None]:
# # Initialize the network with appropriate dimensions for CIFAR-10 (3 channel 32x32 input, 10 classes)
# net = Net(input_channels=3, num_classes=10, input_size=32)
# net.to(device)  # Move the network to the configured device


In [None]:
# Initialize data loaders for CIFAR-10. No batch sizes are passed, so
# DataLoaderManager's defaults (50000 for training, 10000 for testing) apply.
data_manager = DataLoaderManager("CIFAR10")
train_loader, test_loader = data_manager.get_data_loaders()


In [None]:
# Fetch a single batch of data: x holds the images, y the integer labels.
# Only this first batch of the training loader is used below.
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)  # Move data to the device


In [None]:
# Inspect the shape of the image batch.
x.shape

In [None]:
# Inspect the shape of the label batch.
y.shape

In [None]:
# Create positive and negative samples:
# positive samples carry the true label overlaid on the image, negative samples
# carry a randomly drawn label that differs from the true one.
x_pos = overlay_y_on_x(x, y)
y_neg = get_y_neg(y)
x_neg = overlay_y_on_x(x, y_neg)

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(10, 10))

# Define a dictionary to map class indices to class names (replace this with your actual classes)
class_dict = {i: 'class_' + str(i) for i in range(10)}

# The images were normalized by the data pipeline, so their values fall outside the
# [0, 1] range that imshow expects for float RGB data (it would clip them).
# Rescale with one shared range taken from the whole batch so the original,
# positive and negative versions of an image stay visually comparable.
display_min = x.min()
display_range = (x.max() - display_min).clamp_min(1e-12)  # Guard against a constant batch.


def to_displayable(image):
    """Rescale a (channels, height, width) tensor to [0, 1] and return it as a (height, width, channels) array."""
    scaled = ((image - display_min) / display_range).clamp(0.0, 1.0)
    return scaled.cpu().numpy().transpose(1, 2, 0)


for i in range(5):
    img = to_displayable(x[i])
    pos_img = to_displayable(x_pos[i])
    neg_img = to_displayable(x_neg[i])

    axs[i, 0].imshow(img)
    axs[i, 0].set_title('Original: ' + class_dict[int(y[i])] + '\n Shape: ' + str(img.shape))

    axs[i, 1].imshow(pos_img)
    axs[i, 1].set_title('Positive: ' + class_dict[int(y[i])] + '\n Shape: ' + str(pos_img.shape))

    axs[i, 2].imshow(neg_img)
    axs[i, 2].set_title('Negative: ' + class_dict[int(y_neg[i])] + '\n Shape: ' + str(neg_img.shape))

for ax in axs.flat:
    ax.axis('off')

plt.tight_layout()  # Adjusts subplot params so that subplots are nicely fit in the figure
plt.show()

In [None]:
# for data, name in zip([x, x_pos, x_neg], ['orig', 'pos', 'neg']):
#   visualize_sample(data, name)

In [None]:
# Train the network layer by layer on the positive/negative batch.
# custom_train returns the last layer's outputs for both sample sets; as the
# final expression of the cell, that tuple is what the notebook displays.
net.custom_train(x_pos, x_neg)

In [None]:
# Error on the training batch: one minus the mean of the element-wise matches
# between net.predict(x) and the true labels y.
print('train error:', 1.0 - net.predict(x).eq(y).float().mean().item())