1. Imports

In [7]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from opacus import PrivacyEngine
import numpy as np

2. Model Definition

In [8]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.Flatten(),
        )
        sample_input = torch.zeros(1, 1, 28, 28)
        sample_output = self.features(sample_input)
        num_features = sample_output.shape[1]

        self.classifier = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

3. Dataset Definition

In [9]:
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 9852248.71it/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 673121.92it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 5383139.36it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 25640011.80it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






4. Define Modified Dataset and Loaders

In [10]:
train_dataset_removed = Subset(train_dataset, list(range(1, len(train_dataset))))

train_loader_full = DataLoader(train_dataset, batch_size=64, shuffle=True)
train_loader_removed = DataLoader(train_dataset_removed, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

5. Define Model Training and Predictions

In [11]:
def train(model, dataloader, epsilon, delta=1e-5, epochs=1):
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    privacy_engine = PrivacyEngine()
    model, optimizer, dataloader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=dataloader,
        noise_multiplier=1.0,
        max_grad_norm=1.0,
    )
    
    model.train()
    for epoch in range(epochs):
        for data, target in dataloader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model
    
def get_predictions(model, dataloader):
    model.eval()
    preds = []
    with torch.no_grad():
        for data, _ in dataloader:
            output = model(data)
            preds.append(output.softmax(dim=1).cpu().numpy())
    return np.vstack(preds)

6. Model Training

In [12]:

model_full = SimpleCNN()
model_full = train(model_full, train_loader_full, epsilon=1.0)

model_removed = SimpleCNN()
model_removed = train(model_removed, train_loader_removed, epsilon=1.0)




7. Comparison of Outputs 

In [None]:
preds_full = get_predictions(model_full, test_loader)
preds_removed = get_predictions(model_removed, test_loader)

# Compute average absolute difference in predictions
#Show that the difference is negligible 
difference = np.abs(preds_full - preds_removed).mean()
print(f"Average difference in predictions: {difference:.6f}")
print(f"Full Prediction Mean: {preds_full}")
print(f"Removed Prediction Mean: {preds_removed}")

Average difference in predictions: 0.020119
Full Prediction Mean: [[1.16137744e-09 1.09216256e-16 1.58304394e-11 ... 9.99998808e-01
  2.88962582e-12 1.96775673e-07]
 [1.26322730e-09 2.23739707e-10 9.99986053e-01 ... 5.85302076e-22
  9.58533633e-07 1.87259794e-18]
 [8.71668224e-11 9.99991655e-01 1.14783165e-06 ... 8.02043473e-07
  4.28312248e-08 1.01641922e-08]
 ...
 [1.72753259e-15 3.54952802e-13 3.81870506e-14 ... 2.70993716e-08
  3.41880835e-08 3.72739196e-05]
 [2.78639112e-09 5.40264264e-05 1.64589253e-09 ... 5.71252867e-09
  2.42986381e-01 6.50359055e-08]
 [1.76302428e-11 1.33886705e-21 1.27881628e-06 ... 1.98587529e-24
  1.26293968e-13 1.79071844e-16]]
Removed Prediction Mean: [[6.7381331e-12 7.2091879e-16 5.1216786e-14 ... 9.9999976e-01
  9.4268820e-13 2.1180117e-07]
 [3.6349634e-04 8.4612948e-13 9.9940252e-01 ... 3.4940552e-20
  1.1668764e-06 4.6668739e-17]
 [8.3997147e-09 9.9998653e-01 2.5193992e-06 ... 3.8106460e-08
  6.1010264e-06 7.7801104e-08]
 ...
 [3.4553022e-14 3.9091008