In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
from IPython.display import clear_output
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
from torch.optim import Adam

In [None]:
# Draws one *incorrect* label per sample. The Forward-Forward training below builds
# its negative data by overlaying these wrong labels on the real images.
def get_y_neg(y, num_classes=10, device=None):
    """Return random labels that differ element-wise from ``y``.

    Each negative label is drawn uniformly from the ``num_classes - 1`` classes
    other than the true one, by adding a random non-zero offset modulo
    ``num_classes``. This is vectorised; there is no Python loop over the batch.

    Args:
        y: 1-D integer tensor of true labels in ``[0, num_classes)``.
        num_classes: Number of classes (default 10).
        device: Device for the result. Defaults to ``y.device``, so the helper
            no longer depends on a notebook-global ``device`` variable.

    Returns:
        Tensor with the same shape and dtype as ``y``.

    Raises:
        ValueError: if ``num_classes < 2`` or a label lies outside ``[0, num_classes)``.
    """
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2 to draw a different label")
    if y.numel() and (int(y.min()) < 0 or int(y.max()) >= num_classes):
        raise ValueError(f"labels must lie in [0, {num_classes})")
    # Offset in [1, num_classes - 1] can never map a label back onto itself.
    offset = torch.randint(1, num_classes, y.shape, device=y.device, dtype=y.dtype)
    y_neg = (y + offset) % num_classes
    return y_neg.to(device if device is not None else y.device)

In [None]:
def overlay_y_on_x(x, y, classes=10):
    """Embed a label code into a copy of an image batch.

    Row 0, columns ``0..classes-1`` of every channel are zeroed, then the column
    selected by ``y`` is set to 1 in every channel. ``x`` itself is not modified.

    Args:
        x: Image batch, indexed here as (batch, channels, height, width).
        y: One label per sample (1-D tensor) or a single int used for all samples.
        classes: Number of leading columns of row 0 reserved for the label code.
    """
    marked = x.clone()
    # Multiply by 0.0 (rather than assigning 0) to keep the exact original semantics.
    marked[:, :, 0, :classes].mul_(0.0)
    rows = torch.arange(x.shape[0])
    marked[rows, :, 0, y] = 1
    return marked

In [None]:
# Convolutional Forward-Forward network for CIFAR-10.
class Net(torch.nn.Module):
    """Stack of FF-trainable conv ``Layer`` blocks interleaved with ``nn.ReLU``.

    Each trainable layer owns its own optimiser and is trained greedily by
    ``custom_train``. ``predict`` scores every candidate label by the goodness
    the network produces when that label is overlaid on the input.
    """

    def __init__(self):
        super().__init__()
        # The numbers in the comments are positions inside nn.Sequential; they
        # determine state_dict keys ("layers.0.conv.weight", ...), so keep them.
        self.layers = nn.Sequential(
            Layer(3, 64, kernel_size=5, padding=2),  # Layer 0
            nn.ReLU(inplace=True),

            Layer(64, 128, kernel_size=5, padding=2, stride=2),  # Layer 2
            nn.ReLU(inplace=True),

            Layer(128, 256, kernel_size=5, padding=2, stride=2),  # Layer 4
            nn.ReLU(inplace=True),

            Layer(256, 512, kernel_size=5, padding=2, stride=2),  # Layer 6
            nn.ReLU(inplace=True),

            Layer(512, 1024, kernel_size=5, padding=2, stride=2),  # Layer 8
            nn.ReLU(inplace=True),

            Layer(1024, 2048, kernel_size=5, padding=2, stride=2),  # Layer 10
            nn.ReLU(inplace=True),

            Layer(2048, 10, kernel_size=1, stride=1),  # Layer 12
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.layers(x)

    @staticmethod
    def _is_ff_layer(module):
        """True for modules that are trained with the Forward-Forward loss."""
        # Duck typing instead of isinstance(module, Layer): later notebook cells rebind
        # the global names Layer/Net, which would make a re-run silently skip training.
        return hasattr(module, "custom_train")

    def predict(self, x):
        """Predict one label per sample in ``x``.

        For every candidate label the label is overlaid on the batch and the
        batch is passed through the network. Per-sample goodness (mean squared
        activation over all non-batch dims) is summed over the FF-trainable
        layers, and the label with the highest total goodness wins.

        Returns:
            1-D tensor of shape (batch,) with the predicted labels.
        """
        scores = []
        with torch.no_grad():  # inference only; don't build ten autograd graphs
            for label in range(10):
                h = overlay_y_on_x(x, label)
                goodness = torch.zeros(x.shape[0], device=x.device)
                for layer in self.layers:
                    h = layer(h)
                    if self._is_ff_layer(layer):
                        # Per-sample goodness, matching Layer.custom_train. Previously this
                        # was reduced over the whole batch, so a single label was
                        # predicted for every sample.
                        goodness = goodness + h.pow(2).flatten(1).mean(1)
                scores.append(goodness)
        # (batch, 10) -> best label per sample.
        return torch.stack(scores, dim=1).argmax(dim=1)

    def custom_train(self, x_pos, x_neg):
        """Greedily train each FF layer on positive/negative data.

        Trainable layers are trained in order on the (detached) outputs of the
        layers before them; other modules (the extra ReLUs) are simply applied.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            if self._is_ff_layer(layer):
                print("training layer: ", i)
                h_pos, h_neg = layer.custom_train(h_pos, h_neg)
            else:  # activation modules: just pass the data through
                h_pos = layer(h_pos)
                h_neg = layer(h_neg)

In [None]:
# Fully connected Forward-Forward layer (the conv Net above does not use it).
class FullyConnectedLayer(nn.Module):
    """Linear + ReLU layer trained locally with the Forward-Forward loss.

    Args:
        in_features, out_features, bias: forwarded to ``nn.Linear``.
        final_layer: stored but currently unused.

    Learning rate, threshold, number of epochs and logging interval are read
    from the notebook-global ``args`` at construction time.
    """

    def __init__(self, in_features, out_features, bias=True, final_layer=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.relu = nn.ReLU()
        self.opt = Adam(self.parameters(), lr=args.lr)
        self.fcl_threshold = args.fcl_threshold
        self.num_epochs = args.epochs
        self.log_interval = args.log_interval
        self.final_layer = final_layer

    def forward(self, x):
        # Scale each sample to (approximately) unit L2 norm along dim 1 before the linear map.
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return self.relu(self.linear(x_direction))

    @staticmethod
    def _goodness(h):
        """Per-sample goodness: mean squared activation over all non-batch dims."""
        return h.pow(2).flatten(1).mean(1)

    def custom_train(self, x_pos, x_neg):
        """Run ``num_epochs`` FF updates; return detached outputs for the next layer."""
        for i in range(self.num_epochs):
            # One goodness value per sample. The previous version averaged over the
            # whole batch, collapsing the loss to a single pos/neg pair.
            g_pos = self._goodness(self.forward(x_pos))
            g_neg = self._goodness(self.forward(x_neg))
            # softplus(z) == log(1 + exp(z)) but does not overflow for large z.
            loss = nn.functional.softplus(
                torch.cat([-g_pos + self.fcl_threshold, g_neg - self.fcl_threshold])
            ).mean()
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            if i % self.log_interval == 0:
                print("Loss: ", loss.item())
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()

In [None]:
# Convolutional Forward-Forward layer used by Net.
class Layer(nn.Module):
    """Conv2d + ReLU layer trained locally with the Forward-Forward loss.

    Learning rate, threshold, number of epochs and logging interval are read
    from the notebook-global ``args`` at construction time. ``final_layer`` is
    stored but currently unused.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True, final_layer=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.relu = nn.ReLU()
        self.opt = Adam(self.parameters(), lr=args.lr)
        self.conv_threshold = args.conv_threshold
        self.num_epochs = args.epochs
        self.log_interval = args.log_interval
        self.final_layer = final_layer

    def forward(self, x):
        # Divide each spatial position by the L2 norm of its channel vector (dim 1).
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return self.relu(self.conv(x_direction))

    @staticmethod
    def _goodness(h):
        """Per-sample goodness: mean squared activation over channels and space."""
        return h.pow(2).mean(dim=(1, 2, 3))

    @staticmethod
    def _ff_loss(g_pos, g_neg, threshold):
        """Forward-Forward loss pushing g_pos above and g_neg below ``threshold``.

        ``softplus(z)`` equals ``log(1 + exp(z))`` but does not overflow to inf
        for large ``z``.
        """
        return nn.functional.softplus(torch.cat([threshold - g_pos, g_neg - threshold])).mean()

    def _plot_history(self, loss_values, g_pos_values, g_neg_values):
        """Redraw the three training curves in a fresh 12x8 figure."""
        fig, axes = plt.subplots(3, 1, figsize=(12, 8))
        curves = (
            (loss_values, "blue", "Loss during training"),
            (g_pos_values, "green", "g_pos during training"),
            (g_neg_values, "red", "g_neg during training"),
        )
        for ax, (values, color, title) in zip(axes, curves):
            ax.plot(values, color=color)
            ax.set_title(title)
        fig.tight_layout()
        clear_output(wait=True)  # replace the previous plot instead of stacking outputs
        plt.show()
        plt.close(fig)  # release the figure (matters for non-inline backends)

    def custom_train(self, x_pos, x_neg):
        """Train this layer for ``num_epochs`` steps and return detached outputs.

        Every ``log_interval`` steps the loss and the batch-mean goodness values
        are recorded, plotted and printed.
        """
        loss_values, g_pos_values, g_neg_values = [], [], []
        for i in range(self.num_epochs):
            g_pos = self._goodness(self.forward(x_pos))
            g_neg = self._goodness(self.forward(x_neg))
            loss = self._ff_loss(g_pos, g_neg, self.conv_threshold)
            self.opt.zero_grad()
            # The graph is rebuilt every step, so retain_graph=True only wasted memory.
            loss.backward()
            self.opt.step()

            if i % self.log_interval == 0:
                loss_values.append(loss.item())
                g_pos_values.append(g_pos.mean().item())
                g_neg_values.append(g_neg.mean().item())
                self._plot_history(loss_values, g_pos_values, g_neg_values)
                print(f'Loss at step {i}: {loss.item()}')

        return self.forward(x_pos).detach(), self.forward(x_neg).detach()

In [None]:
class Args:
    """Notebook-wide hyper-parameters, read as class attributes through ``args``."""
    train_size = 1000 #10000 #50000   # size of the single training batch
    test_size = 100 #2000 #10000      # size of the single test batch
    epochs = 1000            # optimisation steps per layer
    lr = 0.05
    no_cuda = False
    no_mps = False
    save_model = False
    fcl_threshold = 1        # FF threshold used by FullyConnectedLayer
    conv_threshold = 0.02    # FF threshold used by the conv Layer
    seed = 1234
    log_interval = 10

args = Args()
# `seed` was defined but never applied; seed torch so runs are reproducible.
torch.manual_seed(args.seed)

use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()
if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Shuffle on every device. Previously `shuffle` was only set inside the CUDA
# options, so the data order depended on the hardware.
train_kwargs = {"batch_size": args.train_size, "shuffle": True}
test_kwargs = {"batch_size": args.test_size, "shuffle": True}
if use_cuda:
    cuda_kwargs = {"num_workers": 1, "pin_memory": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [None]:
# Load data: CIFAR-10 train/test splits as tensors, standardised per channel.
# NOTE(review): this std triple is widely copied; CIFAR-10's per-pixel channel std
# is closer to (0.247, 0.243, 0.261) — confirm which is intended before changing it.
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)
transform = Compose([ToTensor(), Normalize(channel_mean, channel_std)])

train_loader, test_loader = (
    DataLoader(CIFAR10("./data/", train=is_train, download=True, transform=transform), **loader_kwargs)
    for is_train, loader_kwargs in ((True, train_kwargs), (False, test_kwargs))
)

In [None]:
# Create the Forward-Forward network, then move its parameters to the chosen device.
net = Net()
net = net.to(device)

In [None]:
import torch
import torch.onnx

# Export the CIFAR-10 FF network to ONNX.
# NOTE: this cell runs before net.custom_train(...), so the exported weights are
# the untrained initial ones unless the cell is re-run after training.
# Trace on the model's own device (a CPU dummy input fails when `net` lives on
# CUDA/MPS) and at the 32x32 CIFAR-10 resolution the network is trained on.
# Only the batch axis is dynamic, so the traced H/W are baked into the file.
model_device = next(net.parameters()).device
dummy_input = torch.randn(1, 3, 32, 32, device=model_device)
onnx_file_path = "model.onnx"

# Export the model to ONNX format
torch.onnx.export(
    net,                    # The PyTorch model
    dummy_input,            # A dummy input tensor for tracing
    onnx_file_path,         # File to save the ONNX model
    export_params=True,     # Store the trained parameter weights inside the model file
    opset_version=11,       # ONNX opset version (11 is commonly used)
    do_constant_folding=True,  # Perform constant folding for optimization
    input_names=["input"],  # Input tensor name
    output_names=["output"],  # Output tensor name
    dynamic_axes={
        "input": {0: "batch_size"},  # Allow dynamic batch size
        "output": {0: "batch_size"}
    }
)

print(f"Model has been exported to {onnx_file_path}")


In [None]:
import torchvision
from torchview import draw_graph

# Render the module hierarchy of `net` for a single CIFAR-10-sized input.
model_graph = draw_graph(net, input_size=(1, 3, 32, 32), depth=3, expand_nested=False)

# The original `model_graph.resize_graph` only looked the method up and never called it.
model_graph.resize_graph()
model_graph.visual_graph


In [None]:
from torchviz import make_dot
import torch

# Trace one forward pass and render the autograd graph to network_architecture.png.
# The dummy input must be on the model's device; a CPU tensor fails when `net`
# lives on CUDA/MPS.
dummy_input = torch.randn(1, 3, 32, 32, device=next(net.parameters()).device)
output = net(dummy_input)
dot = make_dot(output, params=dict(net.named_parameters()), show_attrs=False, show_saved=False)
dot.render("network_architecture", format="png")

In [None]:
# One training batch plus its positive (true label) and negative (wrong label) versions.
x, y = (t.to(device) for t in next(iter(train_loader)))
x_pos = overlay_y_on_x(x, y)
y_neg = get_y_neg(y)
x_neg = overlay_y_on_x(x, y_neg)

In [None]:
x.shape  # full shape of the training image batch

In [None]:
x.shape[0]  # number of samples in the batch

In [None]:
y.shape  # one label per sample

In [None]:
# Show a few training images next to their positive (true label) and negative
# (wrong label) versions; the label code is written into row 0 of each image.
n_show = 5
fig, axs = plt.subplots(n_show, 3, figsize=(10, 10))

# Prefer the dataset's real class names (torchvision's CIFAR10 exposes `.classes`);
# fall back to generic placeholders for datasets without that attribute.
class_names = getattr(train_loader.dataset, "classes", None) or [f"class_{i}" for i in range(10)]
class_dict = dict(enumerate(class_names))


def to_display(img_chw):
    """CHW tensor -> HWC numpy array min-max scaled to [0, 1] for imshow.

    The images are channel-normalised, so imshow would otherwise clip values
    outside [0, 1] (with a warning) and show distorted colours.
    """
    img = img_chw.detach().cpu().permute(1, 2, 0)
    return ((img - img.min()) / (img.max() - img.min() + 1e-8)).numpy()


for i in range(n_show):
    panels = [
        ("Original", x[i], y[i]),
        ("Positive", x_pos[i], y[i]),
        ("Negative", x_neg[i], y_neg[i]),
    ]
    for col, (kind, img_t, label) in enumerate(panels):
        img = to_display(img_t)
        axs[i, col].imshow(img)
        axs[i, col].set_title(f"{kind}: {class_dict[int(label)]}\n Shape: {img.shape}")

for ax in axs.flat:
    ax.axis('off')

plt.tight_layout()  # Adjusts subplot params so that subplots are nicely fit in the figure
plt.show()

In [None]:
# Greedy, layer-by-layer Forward-Forward training on the positive/negative batch.
net.custom_train(x_pos=x_pos, x_neg=x_neg)

In [None]:
train_acc = net.predict(x).eq(y).float().mean().item()
print("Train Accuracy: {:.2f}%".format(100 * train_acc))

In [None]:
# Test Model: fetch one evaluation batch and optionally checkpoint the weights.
x_te, y_te = (t.to(device) for t in next(iter(test_loader)))
if args.save_model:
    torch.save(net.state_dict(), "cifar10_ff.pt")

In [None]:
test_acc = net.predict(x_te).eq(y_te).float().mean().item()
print("Test Accuracy: {:.2f}%".format(100 * test_acc))

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader


def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):
    """Build MNIST train/test DataLoaders that yield flattened, standardised images.

    Each image is converted to a tensor, normalised with mean 0.1307 and std
    0.3081, and flattened to 1-D. The training loader shuffles; the test loader
    keeps the dataset order.
    """
    pipeline = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda img: torch.flatten(img)),
    ])

    def make_loader(is_train, batch_size):
        dataset = MNIST('./data/', train=is_train, download=True, transform=pipeline)
        return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)

    return make_loader(True, train_batch_size), make_loader(False, test_batch_size)


def overlay_y_on_x(x, y):
    """Return a copy of ``x`` whose first 10 features one-hot encode ``y``.

    The first 10 columns are zeroed, then column ``y`` of each row is set to the
    largest value found anywhere in ``x``. ``y`` may be a 1-D tensor with one
    label per row, or a single int applied to every row.
    """
    encoded = x.clone()
    encoded[:, :10].mul_(0.0)
    encoded[torch.arange(x.shape[0]), y] = x.max()
    return encoded


class Net(torch.nn.Module):
    """Fully connected Forward-Forward network (MNIST reference setup).

    Args:
        dims: Layer widths; e.g. ``[784, 500, 500]`` builds two ``Layer``s.
    """

    def __init__(self, dims):
        super().__init__()
        # nn.ModuleList (not a plain list) registers the layers, so net.parameters(),
        # net.to(...) and net.state_dict() see them. Fall back to CPU instead of
        # crashing when CUDA is unavailable.
        layer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.layers = torch.nn.ModuleList(
            Layer(dims[d], dims[d + 1]).to(layer_device) for d in range(len(dims) - 1)
        )

    def predict(self, x):
        """Return, per row of ``x``, the label whose overlay gives the highest goodness.

        Goodness is each layer's mean squared activation, summed over layers.
        """
        goodness_per_label = []
        with torch.no_grad():  # inference only: don't keep autograd graphs alive
            for label in range(10):
                h = overlay_y_on_x(x, label)
                goodness = []
                for layer in self.layers:
                    h = layer(h)
                    goodness += [h.pow(2).mean(1)]
                goodness_per_label += [sum(goodness).unsqueeze(1)]
        goodness_per_label = torch.cat(goodness_per_label, 1)
        return goodness_per_label.argmax(1)

    # NOTE: this shadows nn.Module.train(mode); net.eval() / net.train(True) would
    # call it with the wrong arguments. Kept because existing code calls net.train(x_pos, x_neg).
    def train(self, x_pos, x_neg):
        """Greedily train each layer on the previous layer's detached outputs."""
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')
            h_pos, h_neg = layer.train(h_pos, h_neg)


class Layer(nn.Linear):
    """Linear + ReLU layer trained locally with the Forward-Forward loss.

    Args:
        in_features, out_features, bias, device, dtype: as for ``nn.Linear``.
        lr: Adam learning rate (default 0.03, the previous hard-coded value).
        threshold: goodness threshold separating positive from negative data.
        num_epochs: number of optimisation steps performed by ``train``.
    """

    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None,
                 lr=0.03, threshold=2.0, num_epochs=1000):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs

    def forward(self, x):
        # Scale each row to (approximately) unit L2 norm before the affine map.
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # F.linear accepts bias=None; the previous `self.bias.unsqueeze(0)` crashed
        # for layers built with bias=False.
        return self.relu(torch.nn.functional.linear(x_direction, self.weight, self.bias))

    # NOTE: shadows nn.Module.train(mode); kept because Net.train calls layer.train(h_pos, h_neg).
    def train(self, x_pos, x_neg):
        """Run ``num_epochs`` FF updates; return detached outputs for the next layer."""
        for i in tqdm(range(self.num_epochs)):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # The following loss pushes pos (neg) samples to
            # values larger (smaller) than the self.threshold.
            # softplus(z) == log(1 + exp(z)) without overflowing for large z.
            loss = torch.nn.functional.softplus(torch.cat([
                -g_pos + self.threshold,
                g_neg - self.threshold])).mean()
            self.opt.zero_grad()
            # The gradient only reaches this layer's parameters: the inputs come
            # from the previous layer's detached outputs, so this is not end-to-end
            # backpropagation.
            loss.backward()
            self.opt.step()
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()


def visualize_sample(data, name='', idx=0):
    """Display row ``idx`` of a flattened MNIST batch as a 28x28 grayscale image."""
    image = torch.reshape(data[idx].cpu(), (28, 28))
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_title(name)
    ax.imshow(image, cmap="gray")
    plt.show()


if __name__ == "__main__":
    torch.manual_seed(1234)
    # Choose the device at run time instead of hard-coding .cuda(), which fails on
    # machines without CUDA.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = MNIST_loaders()

    net = Net([784, 500, 500])
    x, y = next(iter(train_loader))
    x, y = x.to(run_device), y.to(run_device)
    x_pos = overlay_y_on_x(x, y)
    # Negative labels: a random non-zero offset guarantees y_neg != y. A shuffled copy
    # of y reuses the true label for about 10% of samples, turning them into positives.
    y_neg = (y + torch.randint(1, 10, y.shape, device=y.device)) % 10
    x_neg = overlay_y_on_x(x, y_neg)

    for data, name in zip([x, x_pos, x_neg], ['orig', 'pos', 'neg']):
        visualize_sample(data, name)

    net.train(x_pos, x_neg)

    print('train error:', 1.0 - net.predict(x).eq(y).float().mean().item())

    x_te, y_te = next(iter(test_loader))
    x_te, y_te = x_te.to(run_device), y_te.to(run_device)

    print('test error:', 1.0 - net.predict(x_te).eq(y_te).float().mean().item())

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader

import os
import warnings

# Request expandable segments from the CUDA caching allocator (intended to reduce
# fragmentation). The setting is read when the allocator is initialised, so it has
# no effect if an earlier cell already used the GPU. In that case, set it at the top
# of the notebook, before any CUDA work, or restart the kernel.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
if torch.cuda.is_initialized():
    warnings.warn("CUDA is already initialised; PYTORCH_CUDA_ALLOC_CONF will not take "
                  "effect until the kernel is restarted.")


# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def CIFAR10_loaders(train_batch_size=5000, test_batch_size=1000):
    """Build CIFAR-10 train (shuffled) and test (in order) DataLoaders.

    Images are converted to tensors and standardised per channel.
    """
    normalize = Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # CIFAR-10 normalization
    transform = Compose([ToTensor(), normalize])

    loaders = []
    for is_train, batch_size in ((True, train_batch_size), (False, test_batch_size)):
        dataset = CIFAR10('./data/', train=is_train, download=True, transform=transform)
        loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=is_train))
    return tuple(loaders)

def overlay_y_on_x(x, y, num_classes=10):
    """Return a copy of an image batch with the label written into column 0.

    Rows ``0..num_classes-1`` of column 0 are zeroed in every channel, then row
    ``y`` of column 0 is set to the maximum value of the whole batch ``x``.
    ``x`` is indexed as (batch, channels, height, width). ``y`` is either a 1-D
    tensor of per-sample labels or a single int used for every sample.
    """
    peak = x.max()
    labelled = x.clone()
    labelled[:, :, :num_classes, 0] = 0
    sample_idx = torch.arange(x.shape[0])
    labelled[sample_idx, :, y, 0] = peak
    return labelled

class Net(nn.Module):
    """Convolutional Forward-Forward network built from ``Layer`` blocks.

    Args:
        dims: Channel counts; e.g. ``[3, 32, 32]`` builds 3->32 and 32->32 layers.
    """

    def __init__(self, dims):
        super().__init__()
        self.layers = nn.ModuleList([Layer(dims[i], dims[i+1]).to(device) for i in range(len(dims)-1)])

    def predict(self, x):
        """Return, per sample, the label whose overlay gives the highest total goodness.

        Goodness is each layer's mean squared activation over (C, H, W), summed
        over layers.
        """
        goodness_per_label = []
        # Inference only: without no_grad all ten label passes keep their full
        # autograd graphs alive, which can exhaust GPU memory on large batches.
        with torch.no_grad():
            for label in range(10):
                h = overlay_y_on_x(x, label)
                goodness = []
                for layer in self.layers:
                    h = layer(h)
                    goodness.append(h.pow(2).mean((1, 2, 3)))  # Adjust mean for all spatial dims
                goodness_per_label.append(torch.stack(goodness).sum(0).unsqueeze(1))
        goodness_per_label = torch.cat(goodness_per_label, 1)
        return goodness_per_label.argmax(1)

    # NOTE: shadows nn.Module.train(mode); net.eval() / net.train(True) would call it
    # with the wrong arguments. Kept because existing code calls net.train(x_pos, x_neg).
    def train(self, x_pos, x_neg):
        """Greedily train each layer on the previous layer's detached outputs."""
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')
            h_pos, h_neg = layer.train(h_pos, h_neg)

class Layer(nn.Module):
    """Conv2d (3x3 by default) + ReLU layer trained with the Forward-Forward loss.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        lr, threshold, num_epochs, kernel_size: new keyword arguments whose
            defaults reproduce the previously hard-coded values.
    """

    def __init__(self, in_features, out_features, lr=0.03, threshold=2.0, num_epochs=1000, kernel_size=3):
        super().__init__()
        # padding = kernel_size // 2 keeps H and W unchanged for odd kernel sizes.
        self.conv = nn.Conv2d(in_features, out_features, kernel_size=kernel_size, stride=1,
                              padding=kernel_size // 2).to(device)
        self.relu = nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs

    def forward(self, x):
        # Divide each spatial position by the L2 norm of its channel vector (dim 1).
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return self.relu(self.conv(x_direction))

    @staticmethod
    def _goodness(h):
        """Per-sample goodness: mean squared activation over channels and space."""
        return h.pow(2).mean((1, 2, 3))

    @staticmethod
    def _ff_loss(g_pos, g_neg, threshold):
        """Push g_pos above and g_neg below ``threshold``.

        ``softplus(z)`` equals ``log(1 + exp(z))`` but stays finite for large ``z``.
        """
        return nn.functional.softplus(torch.cat([threshold - g_pos, g_neg - threshold])).mean()

    # NOTE: shadows nn.Module.train(mode); kept because Net.train calls layer.train(h_pos, h_neg).
    def train(self, x_pos, x_neg):
        """Run ``num_epochs`` FF updates; return detached outputs for the next layer."""
        for i in tqdm(range(self.num_epochs)):
            g_pos = self._goodness(self.forward(x_pos))
            g_neg = self._goodness(self.forward(x_neg))
            loss = self._ff_loss(g_pos, g_neg, self.threshold)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()

def visualize_sample(data, name='', idx=0):
    """Display image ``idx`` of a (batch, channels, H, W) tensor as an RGB picture.

    Data from CIFAR10_loaders is channel-normalised, so values fall outside
    [0, 1] and imshow would clip them (with a warning). The image is therefore
    min-max rescaled to [0, 1] for display only.
    """
    img = data[idx].detach().cpu().permute(1, 2, 0)  # CHW -> HWC for matplotlib
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    plt.figure(figsize=(4, 4))
    plt.title(name)
    plt.imshow(img)
    plt.show()

if __name__ == "__main__":
    torch.manual_seed(1234)
    train_loader, test_loader = CIFAR10_loaders()

    # dims are conv channel counts: 3 input (RGB) channels -> 32 -> 32, i.e. two FF layers.
    net = Net([3, 32, 32]).to(device)
    print(net)

    x, y = next(iter(train_loader))
    x, y = x.to(device), y.to(device)
    x_pos = overlay_y_on_x(x, y)
    # Negative labels: a random non-zero offset guarantees y_neg != y. A shuffled copy
    # of y reuses the true label for about 10% of samples, turning them into positives.
    y_neg = (y + torch.randint(1, 10, y.shape, device=y.device)) % 10
    x_neg = overlay_y_on_x(x, y_neg)

    for data, name in zip([x, x_pos, x_neg], ['orig', 'pos', 'neg']):
        visualize_sample(data, name)

    net.train(x_pos, x_neg)

    print('train error:', 1.0 - net.predict(x).eq(y).float().mean().item())

    x_te, y_te = next(iter(test_loader))
    x_te, y_te = x_te.to(device), y_te.to(device)

    print('test error:', 1.0 - net.predict(x_te).eq(y_te).float().mean().item())