In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR



In [28]:
class Net(nn.Module):
    """Small convolutional classifier: two 3x3 convolutions, then two linear layers.

    Every parameterised layer (conv1, conv2, fc1, fc2) gets ``clamp_input`` as a
    forward pre-hook and ``clamp_grad`` as a backward hook. Both callables are
    looked up when ``Net()`` is instantiated, so they must be defined by then.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12, which is what forward() produces for
        # 1x28x28 inputs (28 -> 26 -> 24 after the convs, -> 12 after pooling).
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

        # NOTE(review): recent PyTorch releases document register_backward_hook
        # as deprecated in favour of register_full_backward_hook, whose
        # grad_input semantics differ -- confirm which one is intended.
        for layer in (self.conv1, self.conv2, self.fc1, self.fc2):
            layer.register_backward_hook(clamp_grad)
            layer.register_forward_pre_hook(clamp_input)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.dropout2(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)

In [133]:
# Elementwise clamp bound used by clamp_grad on the gradients it receives.
UPSTREAM_GRAD_BOUND = 1e-3
# Elementwise clamp bound used by clamp_input on every hooked layer's inputs.
INPUT_BOUND = 5.0
# Mini-batch size for both data loaders; also the divisor in add_noise's sensitivity.
BATCH_SIZE = 256

In [134]:
def clamp_grad(self, grad_input, grad_output):
    """Backward hook that clamps each gradient in ``grad_input`` elementwise.

    Args:
        self: the module the hook is attached to (unused).
        grad_input: iterable of gradient tensors; entries may be ``None``.
        grad_output: unused.

    Returns:
        A tuple with one entry per element of ``grad_input``: ``None`` entries
        are passed through, tensors are clamped to
        ``[-UPSTREAM_GRAD_BOUND, UPSTREAM_GRAD_BOUND]``.

    NOTE(review): despite the constant's name, the bound is applied to
    ``grad_input`` and ``grad_output`` is ignored -- confirm this is intended.
    """
    # Identity check: `== None` on a tensor only works through comparison
    # fallback behaviour and is not a reliable "missing gradient" test.
    return tuple(
        None if grad is None else grad.clamp(-UPSTREAM_GRAD_BOUND, UPSTREAM_GRAD_BOUND)
        for grad in grad_input
    )

In [135]:
def clamp_input(self, input):
    """Forward pre-hook that bounds every positional input of a layer.

    Each element of ``input`` is clamped elementwise to
    ``[-INPUT_BOUND, INPUT_BOUND]``; the clamped values are returned as a
    tuple in the same order. ``self`` (the hooked module) is not used.
    """
    return tuple(arg.clamp(min=-INPUT_BOUND, max=INPUT_BOUND) for arg in input)

In [136]:
def train(model, train_loader, optimizer, epoch, log_interval=50):
    """Run one training epoch over ``train_loader``.

    For every batch: zero the gradients, compute the NLL loss of the model
    output, back-propagate and take one optimizer step.

    Args:
        model: module whose output is passed to ``F.nll_loss``.
        train_loader: iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a sized ``.dataset`` (both are used only for
            progress messages).
        optimizer: optimizer bound to ``model``'s parameters.
        epoch: epoch number, used only in the progress messages.
        log_interval: print a progress line every this many batches
            (default 50). A falsy value (``0`` / ``None``) disables the
            per-batch progress lines.

    Returns:
        None. Progress lines and the final rho are printed.
    """
    total_rho = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()

        # Noise injection is currently disabled, so total_rho stays at 0.
        #total_rho += add_noise(model)

        optimizer.step()
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tRho: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), total_rho))
    print('Final rho:', total_rho)

In [137]:
def add_noise(model):
    """Add Gaussian noise in place to the gradient of every 2-D parameter.

    For a parameter of shape ``(output_len, input_len)`` the gradient norm is
    bounded by the L2 norm of an ``input_len x output_len`` matrix whose
    entries all equal ``INPUT_BOUND * UPSTREAM_GRAD_BOUND``. That bound divided
    by ``BATCH_SIZE`` is used as the sensitivity, and the noise scale follows
    ``sigma = sensitivity / sqrt(2 * rho_i)``.

    Args:
        model: module whose ``parameters()`` have already been populated with
            gradients (i.e. call this after ``loss.backward()``).

    Returns:
        The sum of ``rho_i`` over all parameters that received noise
        (``0`` when none did).

    Raises:
        ValueError: if any parameter with a gradient is not 2-D. The bound
            above is only defined for ``(output_len, input_len)`` weights, so
            convolution kernels and bias vectors are rejected. All parameters
            are validated before any gradient is modified.

    NOTE(review): rho_i looks like a per-parameter zCDP budget, and dividing by
    the global BATCH_SIZE assumes every batch is full -- confirm both.
    """
    rho_i = 0.0001

    # Parameters without a gradient are not updated by the optimizer, so they
    # receive no noise and are not counted in the returned total.
    noisy_params = [p for p in model.parameters() if p.grad is not None]

    # Validate everything first so a failure never leaves the model with only
    # some of its gradients perturbed.
    for p in noisy_params:
        if p.dim() != 2:
            raise ValueError(
                'add_noise only supports 2-D (output_len, input_len) parameters, '
                'got shape {}'.format(tuple(p.shape)))

    total_rho = 0
    for p in noisy_params:
        output_len, input_len = p.shape
        # Closed form of the L2 norm of the constant worst-case gradient
        # matrix; avoids materialising an input_len x output_len tensor.
        grad_bound = INPUT_BOUND * UPSTREAM_GRAD_BOUND * (input_len * output_len) ** 0.5

        sensitivity = grad_bound / BATCH_SIZE
        sigma = sensitivity / (2 * rho_i) ** 0.5

        with torch.no_grad():
            # randn_like keeps the noise on the gradient's device and dtype.
            p.grad += sigma * torch.randn_like(p.grad)
        total_rho += rho_i

    return total_rho

In [138]:
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data, target
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [139]:
# Convert images to tensors, then standardise with a fixed mean/std
# (presumably the MNIST training-set statistics -- TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# The training split is downloaded into ./data when missing; the test split is
# read from the same directory.
dataset1 = datasets.MNIST(root='data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(root='data', train=False, transform=transform)

# No `shuffle` argument is passed to either loader.
train_loader = torch.utils.data.DataLoader(dataset=dataset1, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(dataset=dataset2, batch_size=BATCH_SIZE)

model = Net()
# Adadelta is used with its default hyper-parameters.
optimizer = optim.Adadelta(params=model.parameters())

In [140]:
# Step the learning-rate schedule once per epoch, using StepLR's default decay factor.
scheduler = StepLR(optimizer=optimizer, step_size=1)
for epoch in range(1, 2):  # epochs are numbered from 1; this runs exactly one
    train(model=model, train_loader=train_loader, optimizer=optimizer, epoch=epoch)
    test(model=model, test_loader=test_loader)
    scheduler.step()

Final rho: 0

Test set: Average loss: 0.0626, Accuracy: 9800/10000 (98%)



In [122]:
# Scratch check of the shapes produced by the first convolution layer.
t = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
# Evaluated but not displayed: a cell only shows its last expression.
t(torch.ones((1, 1, 28, 28))).shape
t.weight.shape

torch.Size([32, 1, 3, 3])