In [None]:
import torch
import torch.nn.functional as F


# Prefer the GPU, but fall back to the CPU so the models can still be built on
# machines without CUDA (torch.device("cuda") alone fails at first allocation).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_size = 16  # width of the encoder output / Hopfield state vector
label_size = 10  # number of labels (rows of the Hopfield prototype matrix)


class MaxOneHot(torch.autograd.Function):
    """Keep only the largest element of a tensor and zero everything else.

    Forward: returns a tensor with the same shape, dtype and device as
    ``input`` in which the element at ``argmax(input)`` is copied through and
    every other element is 0. The argmax is taken over the flattened tensor,
    so for N-D inputs exactly one element survives.

    Backward: the incoming gradient is routed only to that same element,
    i.e. the op acts as a fixed mask chosen during the forward pass.
    """

    @staticmethod
    def forward(ctx, input):
        # argmax without ``dim`` yields an index into the *flattened* tensor,
        # so read/write through flat views. Indexing the tensor directly would
        # select a whole slice along dim 0 for N-D inputs.
        idx = torch.argmax(input)
        ctx._input_shape = input.shape
        ctx._input_dtype = input.dtype
        ctx._input_device = input.device
        ctx.save_for_backward(idx)
        output = torch.zeros(ctx._input_shape, device=ctx._input_device, dtype=ctx._input_dtype)
        output.view(-1)[idx] = input.reshape(-1)[idx]
        return output

    @staticmethod
    def backward(ctx, grad_output):
        idx, = ctx.saved_tensors
        grad_input = torch.zeros(ctx._input_shape, device=ctx._input_device, dtype=ctx._input_dtype)
        # reshape (not view): grad_output is not guaranteed to be contiguous.
        grad_input.view(-1)[idx] = grad_output.reshape(-1)[idx]
        return grad_input


class Softmax(torch.nn.Module):
    """Single linear layer followed by a softmax, giving label probabilities.

    Args:
        latent_size: number of input features.
        label_size: number of output labels.
    """

    def __init__(self, latent_size, label_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(latent_size, label_size)

    def forward(self, x):
        """Project ``x`` with ``fc1`` and normalise the scores along dim 1."""
        logits = self.fc1(x)
        probabilities = F.softmax(logits, dim=1)
        return probabilities


class Encoder(torch.nn.Module):
    """Convolutional encoder mapping single-channel 28x28 images to latents.

    Two 5x5 convolutions (padding 2, so the spatial size is preserved) with
    ELU activations, then a linear layer and ``tanh``.

    Args:
        latent_size: dimensionality of the output latent vector.
    """

    def __init__(self, latent_size):
        super(Encoder, self).__init__()
        self.latent_size = latent_size

        self.conv1 = torch.nn.Conv2d(1, 32, 5, padding=2)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2)
        self.fc1 = torch.nn.Linear(64 * 28 * 28, latent_size)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, latent_size) with values in [-1, 1].
        """
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        # Flatten per sample. ``view(-1, 64 * 28 * 28)`` would fold the extra
        # pixels of a wrongly sized image into the batch dimension and silently
        # produce bogus rows; flattening from dim 1 makes ``fc1`` reject it.
        x = torch.flatten(x, start_dim=1)
        return torch.tanh(self.fc1(x))


class Hopfield(torch.nn.Module):
    """Hopfield-style associative memory over learnable label prototypes.

    Each of the ``label_size`` rows of ``label_latent_vectors`` is a prototype
    in the latent space. A weight matrix is built from the prototypes (see
    ``_get_weight``). An input state is iterated with sign updates, and the
    lowest-energy state visited is scored against every prototype.

    Args:
        latent_size: dimensionality of the state / prototype vectors.
        label_size: number of prototypes.
    """

    def __init__(self, latent_size, label_size):
        super(Hopfield, self).__init__()
        self.latent_size = latent_size
        self.label_size = label_size
        # Prototypes drawn uniformly from [-1, 1) on the module-level ``device``.
        # Registered as an nn.Parameter so they show up in ``parameters()`` and
        # ``state_dict()`` and follow ``.to()`` / ``.cuda()``. A bare tensor
        # with requires_grad is invisible to all three.
        self.label_latent_vectors = torch.nn.Parameter(
            2 * torch.rand((label_size, latent_size), device=device) - 1
        )
        self.max_one_hot = MaxOneHot.apply

    def forward(self, s):
        """Relax state ``s`` and score it against the label prototypes.

        Args:
            s: 1-D state vector of length ``latent_size`` (a single sample;
               ``s @ w @ s`` must be a scalar for the energy comparison).

        Returns:
            Tensor of shape (label_size,): ``|min_s . prototype_k|`` divided by
            ``latent_size`` for the best-matching prototype k, zeros elsewhere.
        """
        weight = self._get_weight()

        min_e = self._get_energy(s, weight)
        min_s = s

        # ``latent_size`` synchronous updates: each component takes the sign of
        # the local field and keeps its magnitude. Track the lowest-energy state
        # visited, including the initial one. ``s`` is rebound, never mutated
        # in place, so no defensive clone is needed.
        for _ in range(self.latent_size):
            s = torch.sign(weight @ s) * torch.abs(s)
            e = self._get_energy(s, weight)
            if min_e > e:
                min_e = e
                min_s = s

        scores = torch.abs(min_s @ self.label_latent_vectors.T)
        return self.max_one_hot(scores) / self.latent_size

    @staticmethod
    def _get_energy(s, w):
        """Hopfield energy ``-s^T W s`` of state ``s`` under weights ``w``."""
        return - s @ w @ s

    def _get_weight(self):
        """Build the weight matrix from the prototypes (Hebbian rule).

        With ``c = x - mean(x)``, where the mean is one scalar over all
        entries, ``W = (sum_k c_k c_k^T) / label_size`` with the diagonal
        (self-connections) set to zero.

        Returns:
            (latent_size, latent_size) tensor, differentiable with respect to
            the prototypes.
        """
        x = self.label_latent_vectors
        centered = x - torch.mean(x)
        # c^T c equals the sum of the per-row outer products, in one matmul.
        weight = centered.T @ centered
        weight = weight - torch.diag(torch.diag(weight))
        return weight / len(x)


class DeepHopfield(torch.nn.Module):
    """Convolutional encoder feeding two classification heads.

    ``forward`` returns a pair computed from the same encoder latents:
    per-sample outputs of a shared ``Hopfield`` memory, and the probabilities
    of a linear ``Softmax`` head.

    Args:
        latent_size: encoder output width.
        label_size: number of labels.
    """

    def __init__(self, latent_size, label_size):
        super(DeepHopfield, self).__init__()
        self.encoder = Encoder(latent_size).to(device)
        self.softmax = Softmax(latent_size, label_size).to(device)
        # Hopfield allocates its prototypes on ``device`` itself.
        self.hopfield = Hopfield(latent_size, label_size)
        self.label_size = label_size
        self.latent_size = latent_size

    def forward(self, x):
        """Encode ``x`` and run both heads.

        Returns:
            ``(hopfield_labels, softmax_probs)``; each has shape
            (batch, label_size).
        """
        latent_vectors = self.encoder(x)

        # Hopfield.forward handles one 1-D state at a time, so process the
        # samples individually and re-stack them along the batch dimension.
        hopfield_labels = torch.stack([
            self.hopfield(latent_vector)
            for latent_vector in latent_vectors
        ])

        return hopfield_labels, self.softmax(latent_vectors)

    def optimizer(self, lr=0.01):
        """Create an Adam optimizer for the encoder, softmax head and prototypes.

        The Hopfield prototypes are passed explicitly as their own group.

        Args:
            lr: learning rate (defaults to the previously hard-coded 0.01).
        """
        return torch.optim.Adam([
            {'params': self.encoder.parameters()},
            {'params': self.softmax.parameters()},
            {'params': self.hopfield.label_latent_vectors},
        ], lr=lr)



In [None]:
if __name__ == "__main__":
    from torchvision import datasets, transforms

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()), batch_size=512, shuffle=True
    )

    model = DeepHopfield(latent_size, label_size)
    optimizer = model.optimizer()

    cross_entropy_loss_fn = torch.nn.CrossEntropyLoss().to(device)
    for epoch in range(10000):
        for b, (data, target) in enumerate(train_loader):
            # Move batches to the shared ``device`` instead of hard-coding .cuda().
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            label0, label1 = model(data)

            # Only the Hopfield head is trained.
            # NOTE: ``label1`` is already softmax-normalised, while CrossEntropyLoss
            # expects raw logits; feed it pre-softmax scores if re-enabled.
            cross_entropy_loss0 = cross_entropy_loss_fn(label0, target)
            # cross_entropy_loss1 = cross_entropy_loss_fn(label1, target)
            loss = cross_entropy_loss0  # + cross_entropy_loss1

            loss.backward()
            # L1 size of the prototype gradient, as a Python float for printing.
            grad_l1 = torch.sum(torch.abs(model.hopfield.label_latent_vectors.grad)).item()
            print(f"[Epoch: {epoch: 05d} | {b: 02d} / {len(train_loader) - 1: 02d}] "
                  f"loss: {loss.item(): .5f} grad: {grad_l1}", end='\r')
            optimizer.step()

        print("")