In [1]:
import torch
import torch.utils.data
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms


import ignite
from ignite.engine import Events, Engine
from ignite.metrics import Accuracy, Loss

import numpy as np
import sklearn.datasets
from tqdm.notebook import trange, tqdm

import matplotlib.pyplot as plt
import seaborn as sns

sns.set()

## MODEL

In [2]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 128, 3)
        self.bn3 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(2 * 2 * 128, 256)

    def compute_features(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2, 2)

        x = x.flatten(1)

        x = F.relu(self.fc1(x))

        return x


class CNN_DUQ(Model):
    def __init__(
        self,
        input_size,
        num_classes,
        embedding_size,
        learnable_length_scale,
        length_scale,
        gamma,
    ):
        super().__init__()

        self.gamma = gamma

        self.W = nn.Parameter(
            torch.normal(torch.zeros(embedding_size, num_classes, 256), 0.05)
        )

        self.register_buffer("N", torch.ones(num_classes) * 12)
        self.register_buffer(
            "m", torch.normal(torch.zeros(embedding_size, num_classes), 1)
        )

        self.m = self.m * self.N.unsqueeze(0)

        if learnable_length_scale:
            self.sigma = nn.Parameter(torch.zeros(num_classes) + length_scale)
        else:
            self.sigma = length_scale

    def update_embeddings(self, x, y):
        z = self.last_layer(self.compute_features(x))

        # normalizing value per class, assumes y is one_hot encoded
        self.N = self.gamma * self.N + (1 - self.gamma) * y.sum(0)

        # compute sum of embeddings on class by class basis
        features_sum = torch.einsum("ijk,ik->jk", z, y)

        self.m = self.gamma * self.m + (1 - self.gamma) * features_sum

    def last_layer(self, z):
        z = torch.einsum("ij,mnj->imn", z, self.W)
        return z

    def output_layer(self, z):
        embeddings = self.m / self.N.unsqueeze(0)

        diff = z - embeddings.unsqueeze(0)
        distances = (-(diff ** 2)).mean(1).div(2 * self.sigma ** 2).exp()

        return distances

    def forward(self, x):
        z = self.last_layer(self.compute_features(x))
        y_pred = self.output_layer(z)

        return z, y_pred

In [3]:
input_size = 28
num_classes = 10
embedding_size = 256
learnable_length_scale = False
gamma = 0.999
# length_scales = [0.05, 0.1, 0.2, 0.3, 0.5, 1.0]
length_scale = 0.3

model = CNN_DUQ(
    input_size,
    num_classes,
    embedding_size,
    learnable_length_scale,
    length_scale,
    gamma,
)

optimizer = torch.optim.SGD(
    model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4
)

## Dataset

In [4]:
transform = torchvision.transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST("./", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST("./", download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, num_workers=4, shuffle=False)

In [5]:
dataloader_iterator = iter(train_loader)
x, y = next(dataloader_iterator)
y = F.one_hot(y, num_classes=10).float()
y[0]
# img = x[2].view(28, 28).data

# plt.figure()
# plt.imshow(img, cmap='gray')
# plt.show()   
# type(x[0].shape)
# print(x[0])
a = model(x)
# a ** 2
# x torch.Size([128, 1, 28, 28])
# z torch.Size([128, 256, 10])
# DIFF torch.Size([128, 256, 10])

### LOSS

In [6]:
def calc_gradient_penalty(x, y_pred_sum):
    gradients = torch.autograd.grad(
        outputs=y_pred_sum,
        inputs=x,
        grad_outputs=torch.ones_like(y_pred_sum),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.flatten(start_dim=1)

    # L2 norm
    grad_norm = gradients.norm(2, dim=1)

    # Two sided penalty
    gradient_penalty = ((grad_norm - 1) ** 2).mean()

    return gradient_penalty

## Traning 

In [9]:
# def eval(epoch):
#     model.eval()
#     test_loss = 0
#     with torch.no_grad():
#         for i, (data, labels) in enumerate(test_loader):
#             recon_batch, mu, logvar = model(data.reshape(-1, 784), idx2onehot(labels.view(-1, 1)))
#             test_loss += loss_function(recon_batch, data, mu, logvar).item()
  




epochs = 30
l_gradient_penalty = 0.05

for epoch in range(epochs):
    loss = 0
    bce = 0
    GP = 0
    for i, batch in enumerate(train_loader):
#         with tqdm(total=len(train_loader), desc='Training epoch {}'.format(2)) as p_bar:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model.train()
        optimizer.zero_grad()

        x, y = batch
        y = F.one_hot(y, num_classes=10).float()

        x, y = x.to(device), y.to(device)

        x.requires_grad_(True)

        z, y_pred = model(x)
        bce_loss = F.binary_cross_entropy(y_pred, y)
        GP_loss = l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1))
        train_loss = bce_loss + GP_loss
        x.requires_grad_(False)

        train_loss.backward()
        optimizer.step()
        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)   
        loss += train_loss.item()
        bce += bce_loss.item()
        GP += GP_loss.item()
        print("iteration : {}/{} LOSS : {}".format(i, len(train_loader), train_loss.item()))
#         pbar.update(1)
    loss = loss/len(train_loader)
    bce = bce/len(train_loader)
    GP = GP/len(train_loader)
    
#     model.eval()
#     eval(epoch)
    
#     print(
#         f"Validation Results - Epoch: {trainer.state.epoch} "
#         f"Acc: {metrics['accuracy']:.4f} "
#         f"BCE: {metrics['bce']:.2f} "
#         f"GP: {metrics['gradient_penalty']:.6f} "
#         f"AUROC MNIST: {roc_auc_mnist:.2f} "
#         f"AUROC NotMNIST: {roc_auc_notmnist:.2f} "
#     )
#     print(f"Sigma: {model.sigma}")
    print("EPOCH {}/{}, loss = {:.6f}, kl_loss = {:.6f}, bce_loss = {:.6f}".format(epoch+1, epochs, loss, bce, GP))

    torch.save(model.state_dict(), './duq_mnist_{}.pth'.format(epoch+1))

iteration : 0/469 LOSS : 0.023798281326889992
iteration : 1/469 LOSS : 0.03047027997672558
iteration : 2/469 LOSS : 0.026705896481871605


Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/Cellar/python/3.7.5/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

In [10]:
model.load_state_dict(torch.load('./cvae.pth'))

<All keys matched successfully>