In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, Subset
from tqdm import tqdm
from PIL import Image
import numpy as np


num_epochs_target = 10
num_epochs_shadow = 10
num_epochs_attack = 20

# NOTE: this code is only tested on CPU, there may be some issues on GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

# define transform and load dataset
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

train_set = dsets.MNIST(root="./data", train=True, download=True, transform=transform)
train_set = TensorDataset(train_set.data, train_set.targets)
test_set = dsets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

device: cpu


In [2]:
split_size = len(train_set) // 4

data_train_shadow = Subset(train_set, range(0, split_size))
data_out_shadow = Subset(train_set, range(split_size, split_size * 2))
data_train_attack = Subset(train_set, range(0, split_size * 2))
data_train_target = Subset(train_set, range(split_size * 2, split_size * 3))
data_nonmember_target = Subset(train_set, range(split_size * 3, len(train_set)))
data_eval_attack = Subset(train_set, range(split_size * 2, len(train_set)))


# make sure the splitted dataset are transformed
class TransformedTensorDataset(TensorDataset):
    def __init__(self, data_tensor, target_tensor, transform=None):
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transform = transform
        self.tensors = (data_tensor, target_tensor)

    def __getitem__(self, index):
        data = self.data_tensor[index]
        target = self.target_tensor[index]
        data = Image.fromarray(data.numpy(), mode="L")
        if self.transform:
            data = self.transform(data)
        return data, target

    def __len__(self):
        return self.data_tensor.size(0)


def subset_to_tensor(subset):
    return TransformedTensorDataset(
        torch.stack([subset[i][0] for i in range(len(subset))]),
        torch.tensor([subset[i][1] for i in range(len(subset))]),
        transform=transform,
    )


data_train_shadow = subset_to_tensor(data_train_shadow)
data_out_shadow = subset_to_tensor(data_out_shadow)
data_train_attack = subset_to_tensor(data_train_attack)
data_train_target = subset_to_tensor(data_train_target)
data_nonmember_target = subset_to_tensor(data_nonmember_target)
data_eval_attack = subset_to_tensor(data_eval_attack)

# Create labels for training of attack model
label_train_attack = torch.cat(
    (torch.ones(len(data_train_shadow)), torch.zeros(len(data_out_shadow))), dim=0
)
data_train_attack = TransformedTensorDataset(
    data_train_attack.tensors[0], label_train_attack, transform=transform
)

# create labels for evaluating the attack model
label_eval_attack = torch.cat(
    (torch.ones(len(data_train_target)), torch.zeros(len(data_nonmember_target))), dim=0
)
data_eval_attack = TransformedTensorDataset(
    data_eval_attack.tensors[0], label_eval_attack, transform=transform
)


loader_train_shadow = DataLoader(data_train_shadow, batch_size=64, shuffle=True)
loader_out_shadow = DataLoader(data_out_shadow, batch_size=64, shuffle=True)
loader_train_attack = DataLoader(data_train_attack, batch_size=64, shuffle=True)
loader_train_target = DataLoader(data_train_target, batch_size=64, shuffle=True)
loader_nonmember_target = DataLoader(data_nonmember_target, batch_size=64, shuffle=True)
loader_eval_attack = DataLoader(data_eval_attack, batch_size=64, shuffle=True)

In [3]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


target_model = CNN().to(device)
shadow_model = CNN().to(device)

In [4]:
optimizer_target = torch.optim.Adam(target_model.parameters(), lr=0.001)
loss = 999
for epoch in range(num_epochs_target):
    print(f"epoch {epoch}/{num_epochs_target} loss={loss}")
    loss = 0
    for i, (images, labels) in tqdm(enumerate(loader_train_target)):
        images = images.to(device)
        labels = labels.to(device)
        outputs_target = target_model(images)
        loss_target = F.nll_loss(outputs_target, labels)
        optimizer_target.zero_grad()
        loss_target.backward()
        optimizer_target.step()
        loss += loss_target
target_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in loader_train_target:
        images = images.to(device)
        labels = labels.to(device)
        target = target_model(images)
        _, prediction = torch.max(target, dim=1)
        correct += (prediction == labels).sum().item()
        total += labels.size(0)
    print(f"target model accu (train set): {correct}/{total}={correct/total*100:.2f}%")
correct = 0
total = 0
with torch.no_grad():
    for images, labels in loader_nonmember_target:
        images = images.to(device)
        labels = labels.to(device)
        target = target_model(images)
        _, prediction = torch.max(target, dim=1)
        correct += (prediction == labels).sum().item()
        total += labels.size(0)
    print(f"target model accu (test set): {correct}/{total}={correct/total*100:.2f}%")

epoch 0/10 loss=999


235it [00:04, 58.24it/s]


epoch 1/10 loss=135.6267547607422


235it [00:03, 59.78it/s]


epoch 2/10 loss=34.335731506347656


235it [00:04, 58.61it/s]


epoch 3/10 loss=23.78925323486328


235it [00:03, 60.28it/s]


epoch 4/10 loss=17.746421813964844


235it [00:03, 61.39it/s]


epoch 5/10 loss=14.4429349899292


235it [00:03, 59.88it/s]


epoch 6/10 loss=11.743385314941406


235it [00:03, 60.17it/s]


epoch 7/10 loss=10.147985458374023


235it [00:03, 60.49it/s]


epoch 8/10 loss=7.657524585723877


235it [00:03, 61.59it/s]


epoch 9/10 loss=6.925703048706055


235it [00:03, 60.20it/s]


target model accu (train set): 14938/15000=99.59%
target model accu (test set): 14700/15000=98.00%


In [5]:
optimizer_shadow = torch.optim.Adam(shadow_model.parameters(), lr=0.001)
loss = 999
for epoch in range(num_epochs_shadow):
    print(f"epoch {epoch}/{num_epochs_target} loss={loss}")
    loss = 0
    for i, (images, labels) in tqdm(enumerate(loader_train_shadow)):
        images = images.to(device)
        labels = labels.to(device)
        outputs_shadow = shadow_model(images)
        loss_shadow = F.nll_loss(outputs_shadow, labels)
        optimizer_shadow.zero_grad()
        loss_shadow.backward()
        optimizer_shadow.step()
        loss += loss_shadow
shadow_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in loader_nonmember_target:
        images = images.to(device)
        labels = labels.to(device)
        target = shadow_model(images)
        _, prediction = torch.max(target, dim=1)
        correct += (prediction == labels).sum().item()
        total += labels.size(0)
    print(f"shadow model accu: {correct}/{total}={correct/total*100:.2f}%")

epoch 0/10 loss=999


235it [00:03, 60.01it/s]


epoch 1/10 loss=142.8603973388672


235it [00:03, 61.31it/s]


epoch 2/10 loss=37.368621826171875


235it [00:03, 59.55it/s]


epoch 3/10 loss=24.2596435546875


235it [00:03, 60.82it/s]


epoch 4/10 loss=19.069618225097656


235it [00:03, 61.15it/s]


epoch 5/10 loss=15.747686386108398


235it [00:03, 61.16it/s]


epoch 6/10 loss=12.181062698364258


235it [00:03, 61.50it/s]


epoch 7/10 loss=10.698138236999512


235it [00:03, 61.67it/s]


epoch 8/10 loss=9.221597671508789


235it [00:03, 61.71it/s]


epoch 9/10 loss=8.278692245483398


235it [00:03, 59.47it/s]


shadow model accu: 14654/15000=97.69%


In [10]:
class AttackModel(nn.Module):
    def __init__(self):
        super(AttackModel, self).__init__()
        self.fc1 = nn.Linear(3, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x


attack_model = AttackModel().to(device)

In [11]:
optimizer_attack = torch.optim.Adam(attack_model.parameters(), lr=0.001)
loss = 999
for epoch in range(num_epochs_attack):
    print(f"epoch {epoch}/{num_epochs_attack}, loss={loss}")
    loss = 0
    for idx, (images, labels) in tqdm(enumerate(loader_train_attack)):
        images = images.to(device)
        labels = labels.to(device)
        posteriors_shadow = F.softmax(shadow_model(images), dim=1)
        # print(posteriors_shadow)
        top3_posteriors = torch.topk(posteriors_shadow, 3, dim=1)[0]
        # print(top3_posteriors)
        labels_attack = labels.float().unsqueeze(1)

        # attack_outputs = attack_model(posteriors_shadow)
        attack_outputs = attack_model(top3_posteriors)
        loss_attack = F.binary_cross_entropy(attack_outputs, labels_attack)

        optimizer_attack.zero_grad()
        loss_attack.backward()
        optimizer_attack.step()
        loss += loss_attack

epoch 0/20, loss=999


469it [00:07, 59.35it/s]


epoch 1/20, loss=325.06988525390625


469it [00:08, 58.22it/s]


epoch 2/20, loss=324.98016357421875


469it [00:07, 60.86it/s]


epoch 3/20, loss=324.8493347167969


469it [00:07, 59.14it/s]


epoch 4/20, loss=324.95257568359375


469it [00:07, 60.71it/s]


epoch 5/20, loss=324.9265441894531


469it [00:07, 61.01it/s]


epoch 6/20, loss=324.914794921875


469it [00:07, 59.89it/s]


epoch 7/20, loss=324.9212951660156


469it [00:07, 60.59it/s]


epoch 8/20, loss=324.87933349609375


469it [00:07, 59.53it/s]


epoch 9/20, loss=324.89532470703125


469it [00:08, 57.68it/s]


epoch 10/20, loss=324.9516906738281


469it [00:07, 60.29it/s]


epoch 11/20, loss=324.8806457519531


469it [00:07, 60.76it/s]


epoch 12/20, loss=324.85443115234375


469it [00:07, 60.92it/s]


epoch 13/20, loss=324.8356628417969


469it [00:07, 60.49it/s]


epoch 14/20, loss=324.8828430175781


469it [00:07, 60.35it/s]


epoch 15/20, loss=324.91351318359375


469it [00:07, 61.37it/s]


epoch 16/20, loss=324.8971862792969


469it [00:07, 61.73it/s]


epoch 17/20, loss=324.84417724609375


469it [00:07, 61.57it/s]


epoch 18/20, loss=324.8438720703125


469it [00:07, 61.58it/s]


epoch 19/20, loss=324.8574523925781


469it [00:07, 61.96it/s]


In [12]:
attack_model.eval()
total = 0

all_pred = torch.empty((1))
all_labels = torch.empty((1))

with torch.no_grad():
    for idx, (images, labels) in enumerate(loader_eval_attack):
        images = images.to(device)
        labels = labels.to(device)
        posteriors_target = F.softmax(target_model(images), dim=1)
        top3_posteriors = torch.topk(posteriors_target, 3, dim=1)[0]

        attack_outputs = attack_model(top3_posteriors)

        all_pred = torch.cat((all_pred, attack_outputs.squeeze()), dim=0)
        all_labels = torch.cat((all_labels, labels.squeeze()), dim=0)

In [14]:
for pred_thre in np.arange(0.30, 0.60, 0.02):
    all_pred_bin = all_pred > pred_thre
    all_labels_bin = all_labels > 0.5  # only 0 or 1

    TP = (all_pred_bin & all_labels_bin).sum().item()
    TN = ((~all_pred_bin) & (~all_labels_bin)).sum().item()
    FP = (all_pred_bin & (~all_labels_bin)).sum().item()
    FN = ((~all_pred_bin) & all_labels_bin).sum().item()

    print(
        f"Attack Model Accuracy for thre={pred_thre:.2f} TP={TP} TN={TN} FP={FP} FN={FN} ",
        end="",
    )
    if (TP + FP) == 0 or (TP + FN) == 0:
        print("metrics invalid")
        continue
    accu = (TP + TN) / (TP + TN + FP + FN)
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    f1 = 2 * (precision * recall) / (precision + recall)
    print(f"accu={accu:.2f} " f"prec={precision:.2f} recall={recall:.2f} f1={f1:.2f}")

Attack Model Accuracy for thre=0.30 TP=14999 TN=9 FP=14992 FN=1 accu=0.50 prec=0.50 recall=1.00 f1=0.67
Attack Model Accuracy for thre=0.32 TP=14995 TN=19 FP=14982 FN=5 accu=0.50 prec=0.50 recall=1.00 f1=0.67
Attack Model Accuracy for thre=0.34 TP=14986 TN=29 FP=14972 FN=14 accu=0.50 prec=0.50 recall=1.00 f1=0.67
Attack Model Accuracy for thre=0.36 TP=14975 TN=51 FP=14950 FN=25 accu=0.50 prec=0.50 recall=1.00 f1=0.67
Attack Model Accuracy for thre=0.38 TP=14961 TN=92 FP=14909 FN=39 accu=0.50 prec=0.50 recall=1.00 f1=0.67
Attack Model Accuracy for thre=0.40 TP=14919 TN=169 FP=14832 FN=81 accu=0.50 prec=0.50 recall=0.99 f1=0.67
Attack Model Accuracy for thre=0.42 TP=14863 TN=275 FP=14726 FN=137 accu=0.50 prec=0.50 recall=0.99 f1=0.67
Attack Model Accuracy for thre=0.44 TP=14807 TN=401 FP=14600 FN=193 accu=0.51 prec=0.50 recall=0.99 f1=0.67
Attack Model Accuracy for thre=0.46 TP=14696 TN=511 FP=14490 FN=304 accu=0.51 prec=0.50 recall=0.98 f1=0.67
Attack Model Accuracy for thre=0.48 TP=145