In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, Subset
from tqdm import tqdm
from PIL import Image
import numpy as np


num_epochs_target = 10
num_epochs_shadow = 10
num_epochs_attack = 10

# NOTE: this code is only tested on CPU, there may be some issues on GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

# define transform and load dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

train_set = dsets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_set = TensorDataset(torch.tensor(train_set.data), torch.tensor(train_set.targets))
test_set = dsets.CIFAR10(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

device: cpu
Files already downloaded and verified
Files already downloaded and verified


In [2]:
split_size = len(train_set) // 4

data_train_shadow = Subset(train_set, range(0, split_size))
data_out_shadow = Subset(train_set, range(split_size, split_size * 2))
data_train_attack = Subset(train_set, range(0, split_size * 2))
data_train_target = Subset(train_set, range(split_size * 2, split_size * 3))
data_nonmember_target = Subset(train_set, range(split_size * 3, len(train_set)))
data_eval_attack = Subset(train_set, range(split_size * 2, len(train_set)))


# make sure the splitted dataset are transformed
class TransformedTensorDataset(TensorDataset):
    def __init__(self, data_tensor, target_tensor, transform=None):
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transform = transform
        self.tensors = (data_tensor, target_tensor)

    def __getitem__(self, index):
        data = self.data_tensor[index]
        target = self.target_tensor[index]
        data = Image.fromarray(data.numpy())
        if self.transform:
            data = self.transform(data)
        return data, target

    def __len__(self):
        return self.data_tensor.size(0)


def subset_to_tensor(subset):
    return TransformedTensorDataset(
        torch.stack([subset[i][0] for i in range(len(subset))]),
        torch.tensor([subset[i][1] for i in range(len(subset))]),
        transform=transform,
    )


data_train_shadow = subset_to_tensor(data_train_shadow)
data_out_shadow = subset_to_tensor(data_out_shadow)
data_train_attack = subset_to_tensor(data_train_attack)
data_train_target = subset_to_tensor(data_train_target)
data_nonmember_target = subset_to_tensor(data_nonmember_target)
data_eval_attack = subset_to_tensor(data_eval_attack)

# Create labels for training of attack model
label_train_attack = torch.cat(
    (torch.ones(len(data_train_shadow)), torch.zeros(len(data_out_shadow))), dim=0
)
data_train_attack = TransformedTensorDataset(
    data_train_attack.tensors[0], label_train_attack, transform=transform
)

# create labels for evaluating the attack model
label_eval_attack = torch.cat(
    (torch.ones(len(data_train_target)), torch.zeros(len(data_nonmember_target))), dim=0
)
data_eval_attack = TransformedTensorDataset(
    data_eval_attack.tensors[0], label_eval_attack, transform=transform
)


loader_train_shadow = DataLoader(data_train_shadow, batch_size=64, shuffle=True)
loader_out_shadow = DataLoader(data_out_shadow, batch_size=64, shuffle=True)
loader_train_attack = DataLoader(data_train_attack, batch_size=64, shuffle=True)
loader_train_target = DataLoader(data_train_target, batch_size=64, shuffle=True)
loader_nonmember_target = DataLoader(data_nonmember_target, batch_size=64, shuffle=True)
loader_eval_attack = DataLoader(data_eval_attack, batch_size=64, shuffle=True)

In [3]:
class CNN(nn.Module):
    def __init__(self, input_size=3):
        super(CNN, self).__init__()
        self.input_size = input_size
        self.conv1 = nn.Conv2d(
            in_channels=input_size, out_channels=48, kernel_size=(3, 3)
        )
        self.conv2 = nn.Conv2d(in_channels=48, out_channels=96, kernel_size=(3, 3))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        if input_size == 3:
            self.fc_features = 6 * 6 * 96
        else:
            self.fc_features = 5 * 5 * 96
        self.fc1 = nn.Linear(in_features=self.fc_features, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)
        x = x.view(-1, self.fc_features)  # reshape x
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


target_model = CNN().to(device)
shadow_model = CNN().to(device)
target_loss = nn.CrossEntropyLoss()
shadow_loss = nn.CrossEntropyLoss()

In [4]:
optimizer_target = torch.optim.Adam(target_model.parameters(), lr=0.001)
loss = 999
for epoch in range(num_epochs_target):
    print(f"epoch {epoch}/{num_epochs_target} loss={loss}")
    loss = 0
    for i, (images, labels) in tqdm(enumerate(loader_train_target)):
        images = images.to(device)
        labels = labels.to(device)
        outputs_target = target_model(images)
        loss_target = target_loss(outputs_target, labels)
        optimizer_target.zero_grad()
        loss_target.backward()
        optimizer_target.step()
        loss += loss_target
target_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in loader_train_target:
        images = images.to(device)
        labels = labels.to(device)
        target = target_model(images)
        _, prediction = torch.max(target, dim=1)
        correct += (prediction == labels).sum().item()
        total += labels.size(0)
    print(f"target model accu (train set): {correct}/{total}={correct/total*100:.2f}%")
correct = 0
total = 0
with torch.no_grad():
    for images, labels in loader_nonmember_target:
        images = images.to(device)
        labels = labels.to(device)
        target = target_model(images)
        _, prediction = torch.max(target, dim=1)
        correct += (prediction == labels).sum().item()
        total += labels.size(0)
    print(f"target model accu (test set): {correct}/{total}={correct/total*100:.2f}%")

epoch 0/10 loss=999


196it [00:11, 17.03it/s]


epoch 1/10 loss=338.8419189453125


196it [00:11, 17.17it/s]


epoch 2/10 loss=265.96044921875


196it [00:11, 17.02it/s]


epoch 3/10 loss=228.33941650390625


196it [00:11, 17.05it/s]


epoch 4/10 loss=198.8133544921875


196it [00:11, 16.82it/s]


epoch 5/10 loss=164.3293914794922


196it [00:11, 16.97it/s]


epoch 6/10 loss=133.46311950683594


196it [00:11, 16.95it/s]


epoch 7/10 loss=102.7719955444336


196it [00:11, 17.00it/s]


epoch 8/10 loss=70.64253234863281


196it [00:11, 16.99it/s]


epoch 9/10 loss=45.383392333984375


196it [00:11, 16.99it/s]


target model accu (train set): 12086/12500=96.69%
target model accu (test set): 7738/12500=61.90%


In [5]:
optimizer_shadow = torch.optim.Adam(shadow_model.parameters(), lr=0.001)
loss = 999
for epoch in range(num_epochs_shadow):
    print(f"epoch {epoch}/{num_epochs_target} loss={loss}")
    loss = 0
    for i, (images, labels) in tqdm(enumerate(loader_train_shadow)):
        images = images.to(device)
        labels = labels.to(device)
        outputs_shadow = shadow_model(images)
        loss_shadow = shadow_loss(outputs_shadow, labels)
        optimizer_shadow.zero_grad()
        loss_shadow.backward()
        optimizer_shadow.step()
        loss += loss_shadow
shadow_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in loader_nonmember_target:
        images = images.to(device)
        labels = labels.to(device)
        target = shadow_model(images)
        _, prediction = torch.max(target, dim=1)
        correct += (prediction == labels).sum().item()
        total += labels.size(0)
    print(f"shadow model accu: {correct}/{total}={correct/total*100:.2f}%")

epoch 0/10 loss=999


196it [00:11, 16.67it/s]


epoch 1/10 loss=335.2061462402344


196it [00:11, 17.25it/s]


epoch 2/10 loss=266.6365966796875


196it [00:11, 17.04it/s]


epoch 3/10 loss=226.57061767578125


196it [00:11, 17.02it/s]


epoch 4/10 loss=194.8171844482422


196it [00:11, 17.12it/s]


epoch 5/10 loss=163.1771240234375


196it [00:11, 16.98it/s]


epoch 6/10 loss=132.09356689453125


196it [00:11, 17.26it/s]


epoch 7/10 loss=98.14667510986328


196it [00:11, 17.31it/s]


epoch 8/10 loss=66.05852508544922


196it [00:11, 17.03it/s]


epoch 9/10 loss=41.381263732910156


196it [00:11, 17.17it/s]


shadow model accu: 7715/12500=61.72%


In [6]:
class AttackModel(nn.Module):
    def __init__(self):
        super(AttackModel, self).__init__()
        self.fc1 = nn.Linear(3, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x


attack_model = AttackModel().to(device)

In [7]:
optimizer_attack = torch.optim.Adam(attack_model.parameters(), lr=0.001)
loss = 999
for epoch in range(num_epochs_attack):
    print(f"epoch {epoch}/{num_epochs_attack}, loss={loss}")
    loss = 0
    for idx, (images, labels) in tqdm(enumerate(loader_train_attack)):
        images = images.to(device)
        labels = labels.to(device)
        posteriors_shadow = F.softmax(shadow_model(images), dim=1)
        top3_posteriors = torch.topk(posteriors_shadow, 3, dim=1)[0]
        labels_attack = labels.float().unsqueeze(1)

        # attack_outputs = attack_model(posteriors_shadow)
        attack_outputs = attack_model(top3_posteriors)
        loss_attack = F.binary_cross_entropy(attack_outputs, labels_attack)

        optimizer_attack.zero_grad()
        loss_attack.backward()
        optimizer_attack.step()
        loss += loss_attack

epoch 0/10, loss=999


391it [00:21, 18.35it/s]


epoch 1/10, loss=262.6524963378906


391it [00:22, 17.29it/s]


epoch 2/10, loss=254.68312072753906


391it [00:24, 16.17it/s]


epoch 3/10, loss=254.04678344726562


391it [00:24, 16.09it/s]


epoch 4/10, loss=253.76869201660156


391it [00:22, 17.12it/s]


epoch 5/10, loss=253.57717895507812


391it [00:25, 15.11it/s]


epoch 6/10, loss=253.39329528808594


391it [00:35, 10.91it/s]


epoch 7/10, loss=253.30189514160156


391it [00:34, 11.39it/s]


epoch 8/10, loss=253.22560119628906


391it [00:34, 11.34it/s]


epoch 9/10, loss=253.23565673828125


391it [00:34, 11.43it/s]


In [8]:
attack_model.eval()
total = 0

all_pred = torch.empty((1))
all_labels = torch.empty((1))

with torch.no_grad():
    for images, labels in loader_eval_attack:
        images = images.to(device)
        labels = labels.to(device)
        posteriors_target = F.softmax(target_model(images), dim=1)
        top3_posteriors = torch.topk(posteriors_target, 3, dim=1)[0]

        # attack_outputs = attack_model(posteriors_target)
        attack_outputs = attack_model(top3_posteriors)

        all_pred = torch.cat((all_pred, attack_outputs.squeeze()), dim=0)
        all_labels = torch.cat((all_labels, labels.squeeze()), dim=0)

In [11]:
for pred_thre in np.arange(0.20, 0.65, 0.05):
    all_pred_bin = all_pred > pred_thre
    all_labels_bin = all_labels > 0.5  # only 0 or 1

    TP = (all_pred_bin & all_labels_bin).sum().item()
    TN = ((~all_pred_bin) & (~all_labels_bin)).sum().item()
    FP = (all_pred_bin & (~all_labels_bin)).sum().item()
    FN = ((~all_pred_bin) & all_labels_bin).sum().item()

    print(
        f"Attack Model Accuracy for thre={pred_thre:.2f} TP={TP} TN={TN} FP={FP} FN={FN} ",
        end="",
    )
    if (TP + FP) == 0 or (TP + FN) == 0:
        print("metrics invalid")
        continue
    accu = (TP + TN) / (TP + TN + FP + FN)
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    f1 = 2 * (precision * recall) / (precision + recall)
    print(f"accu={accu:.2f} " f"prec={precision:.2f} recall={recall:.2f} f1={f1:.2f}")

Attack Model Accuracy for thre=0.20 TP=12334 TN=927 FP=11574 FN=166 accu=0.53 prec=0.52 recall=0.99 f1=0.68
Attack Model Accuracy for thre=0.25 TP=12132 TN=1690 FP=10811 FN=368 accu=0.55 prec=0.53 recall=0.97 f1=0.68
Attack Model Accuracy for thre=0.30 TP=11770 TN=2564 FP=9937 FN=730 accu=0.57 prec=0.54 recall=0.94 f1=0.69
Attack Model Accuracy for thre=0.35 TP=11389 TN=3468 FP=9033 FN=1111 accu=0.59 prec=0.56 recall=0.91 f1=0.69
Attack Model Accuracy for thre=0.40 TP=10971 TN=4252 FP=8249 FN=1529 accu=0.61 prec=0.57 recall=0.88 f1=0.69
Attack Model Accuracy for thre=0.45 TP=10642 TN=4806 FP=7695 FN=1858 accu=0.62 prec=0.58 recall=0.85 f1=0.69
Attack Model Accuracy for thre=0.50 TP=10192 TN=5412 FP=7089 FN=2308 accu=0.62 prec=0.59 recall=0.82 f1=0.68
Attack Model Accuracy for thre=0.55 TP=9543 TN=6121 FP=6380 FN=2957 accu=0.63 prec=0.60 recall=0.76 f1=0.67
Attack Model Accuracy for thre=0.60 TP=8156 TN=7322 FP=5179 FN=4344 accu=0.62 prec=0.61 recall=0.65 f1=0.63
