In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
from torch.optim import Adam

In [2]:
def get_y_neg(y, num_classes=10):
    """Draw a random *incorrect* label for every entry of ``y``.

    Each returned label is sampled uniformly from the ``num_classes - 1``
    classes that differ from the corresponding true label.

    Args:
        y: Integer tensor of class labels in ``[0, num_classes)``.
        num_classes: Total number of classes (defaults to 10).

    Returns:
        A new tensor with the same shape, dtype and device as ``y`` in which
        every entry differs from the matching entry of ``y``. ``y`` itself is
        left untouched.
    """
    # Adding an offset in [1, num_classes) modulo num_classes can never land
    # back on the true label, and a uniform offset yields a uniform choice
    # among the remaining classes. This replaces a per-sample Python loop and
    # keeps the result on y's own device instead of reading a global.
    offsets = torch.randint(1, num_classes, y.shape, device=y.device)
    return ((y + offsets) % num_classes).to(y.dtype)


def overlay_y_on_x(x, y, classes=10):
    """Return a copy of ``x`` whose first ``classes`` columns encode ``y``.

    The leading ``classes`` columns of every row are zeroed, then column
    ``y`` of each row is set to the largest value found anywhere in ``x``.
    ``y`` may be a single index (applied to every row) or one index per row.
    The input tensor is not modified.
    """
    peak = x.max()
    overlaid = x.clone()
    # In-place on a slice view, so the zeroing lands in ``overlaid`` itself.
    overlaid[:, :classes].mul_(0.0)
    row_ids = range(overlaid.shape[0])
    overlaid[row_ids, y] = peak
    return overlaid

In [3]:
class Net(torch.nn.Module):
    """A stack of ``Layer`` modules that are trained one after another.

    Every layer optimises its own local objective (see ``Layer.train``);
    prediction picks the label whose overlay produces the largest summed
    "goodness" (mean squared activation) across all layers.

    Args:
        dims: Sequence of layer widths; ``dims[i] -> dims[i + 1]`` defines
            the i-th layer.

    NOTE: ``train`` below shadows ``nn.Module.train(mode)``, so
    ``net.eval()`` / ``net.train(False)`` cannot be used on this class.
    """

    def __init__(self, dims):
        super().__init__()
        # nn.ModuleList (rather than a plain list) registers the layers as
        # sub-modules, so parameters(), state_dict() and .to() can see them.
        # ``device`` is the module-level device chosen elsewhere in the notebook.
        self.layers = nn.ModuleList(
            [Layer(n_in, n_out).to(device) for n_in, n_out in zip(dims, dims[1:])]
        )

    @torch.no_grad()  # inference only: argmax is not differentiable anyway
    def predict(self, x, num_classes=10):
        """Return, for each row of ``x``, the label with the highest goodness.

        Args:
            x: Batch of flattened inputs, one sample per row.
            num_classes: Number of candidate labels to try (defaults to 10).

        Returns:
            1-D tensor of predicted label indices, one per row of ``x``.
        """
        goodness_per_label = []
        for label in range(num_classes):
            h = overlay_y_on_x(x, label, classes=num_classes)
            total_goodness = 0
            for layer in self.layers:
                h = layer(h)
                total_goodness = total_goodness + h.pow(2).mean(1)
            goodness_per_label.append(total_goodness)
        # Shape (batch, num_classes): one goodness column per candidate label.
        return torch.stack(goodness_per_label, dim=1).argmax(1)

    def train(self, x_pos, x_neg):
        """Train the layers in order on positive and negative samples.

        Each layer is trained on the (detached) outputs of the previous one.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print("training layer: ", i)
            h_pos, h_neg = layer.train(h_pos, h_neg)

In [10]:
class Layer(nn.Linear):
    """Linear + ReLU layer that owns its optimiser and trains on a local loss.

    The layer's "goodness" for a sample is the mean of its squared
    activations. Training pushes goodness above ``threshold`` for positive
    samples and below it for negative samples.

    NOTE: ``train`` below shadows ``nn.Module.train(mode)``, so
    ``layer.eval()`` / ``layer.train(False)`` cannot be used on this class.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=0.01)
        self.threshold = 0.5  # goodness level separating positive from negative
        self.num_epochs = 20  # optimisation steps taken per call to train()

    def forward(self, x):
        """Normalise each row of ``x`` to (near) unit L2 norm, then Linear + ReLU."""
        # Dividing by the row norm keeps only the direction of the input;
        # the small constant guards against division by zero.
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # functional.linear handles a missing bias (bias=False) gracefully.
        return self.relu(nn.functional.linear(x_direction, self.weight, self.bias))

    def train(self, x_pos, x_neg):
        """Run ``num_epochs`` optimisation steps on the given samples.

        Args:
            x_pos: Batch of positive inputs, one sample per row.
            x_neg: Batch of negative inputs, one sample per row.

        Returns:
            Tuple ``(h_pos, h_neg)`` with this layer's activations for both
            batches after training, detached from the autograd graph.
        """
        for _ in range(self.num_epochs):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # softplus(z) == log(1 + exp(z)), but evaluated without overflow
            # for large z, so the loss and its gradients stay finite.
            loss = nn.functional.softplus(
                torch.cat([-g_pos + self.threshold, g_neg - self.threshold])
            ).mean()
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            print("Loss: ", loss.item())
        # The next layer must not backpropagate into this one, so the final
        # activations are computed without recording a graph.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
use_cuda = torch.cuda.is_available()
# torch.backends.mps is absent in older PyTorch builds, so probe for it first.
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Shuffle the training data on every device (not just CUDA) so results do not
# depend on the hardware; the test loader keeps its natural order.
train_kwargs = {"batch_size": 32, "shuffle": True}
test_kwargs = {"batch_size": 32}

if use_cuda:
    # Loader settings that are only applied when running on CUDA.
    cuda_kwargs = {"num_workers": 1, "pin_memory": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [None]:
# Per-image preprocessing: tensor conversion, normalisation with the given
# mean/std (presumably the MNIST statistics — TODO confirm), then flattening
# the image into a 1-D vector.
transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,)), Lambda(torch.flatten)])

# Build the train and test loaders the same way; only the split flag and the
# loader keyword arguments differ.
train_loader, test_loader = (
    DataLoader(
        MNIST("./data/", train=is_train, download=True, transform=transform),
        **loader_kwargs,
    )
    for is_train, loader_kwargs in ((True, train_kwargs), (False, test_kwargs))
)

In [11]:
# Seed the global RNG so that weight initialisation, the negative labels and
# (when the loader shuffles) the sampled batch are repeatable across runs.
torch.manual_seed(1234)

net = Net([784, 500, 500])

# Train on a single batch drawn from the training loader.
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)

# Positive samples carry the true label overlay, negatives a wrong one.
x_pos = overlay_y_on_x(x, y)
y_neg = get_y_neg(y)
x_neg = overlay_y_on_x(x, y_neg)
net.train(x_pos, x_neg)

# Evaluation needs no gradients, so skip building autograd graphs.
with torch.no_grad():
    print("train error:", 1.0 - net.predict(x).eq(y).float().mean().item())
    x_te, y_te = next(iter(test_loader))
    x_te, y_te = x_te.to(device), y_te.to(device)
    print("test error:", 1.0 - net.predict(x_te).eq(y_te).float().mean().item())

training layer:  0
Loss:  0.7240239381790161
Loss:  0.7223004102706909
Loss:  0.7182900905609131
Loss:  0.7118580341339111
Loss:  0.7039902806282043
Loss:  0.6968176364898682
Loss:  0.6937904953956604
Loss:  0.6971739530563354
Loss:  0.7003540992736816
Loss:  0.6988317370414734
Loss:  0.6952486038208008
Loss:  0.6924381256103516
Loss:  0.6914573311805725
Loss:  0.691879153251648
Loss:  0.6927556395530701
Loss:  0.6933597326278687
Loss:  0.6933772563934326
Loss:  0.6928091645240784
Loss:  0.691849946975708
Loss:  0.6908039450645447
training layer:  1
Loss:  0.7239922881126404
Loss:  0.7204704880714417
Loss:  0.7137596607208252
Loss:  0.7045049667358398
Loss:  0.6958726644515991
Loss:  0.6934424638748169
Loss:  0.6998155117034912
Loss:  0.7017667293548584
Loss:  0.6981872320175171
Loss:  0.6944157481193542
Loss:  0.6930897831916809
Loss:  0.69392329454422
Loss:  0.6954567432403564
Loss:  0.6965612173080444
Loss:  0.696794331073761
Loss:  0.6961874961853027
Loss:  0.6950468420982361
Loss: