In [1]:
import torch
import requests
import numpy as np
import pandas as pd
import torch.nn as nn
import torchvision.transforms as transforms
from typing import Tuple
from torch.utils.data import Dataset
from torchvision.models import resnet18
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [2]:
main_model = resnet18(pretrained=False)
main_model.fc = torch.nn.Linear(512, 44)
ckpt = torch.load("out/models/attack_model.pt", map_location=device)
main_model.load_state_dict(ckpt)
main_model.eval()

  ckpt = torch.load("out/models/attack_model.pt", map_location=device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
transform_00 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
])
transform_01 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomHorizontalFlip(p=1),  # Apply horizontal flip
])
transform_10 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomVerticalFlip(p=1),    # Apply vertical flip
])
transform_11 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomHorizontalFlip(p=1),  # Apply horizontal flip
    transforms.RandomVerticalFlip(p=1),    # Apply vertical flip
])

In [4]:
class TaskDataset(Dataset):
    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if not self.transform is None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label
    def __len__(self):
        return len(self.ids)

In [5]:
class MembershipDataset(TaskDataset):
    def __init__(self, transform=None):
        super().__init__(transform)
        self.membership = []
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int, int]:
        id_, img, label = super().__getitem__(index)
        return id_, img, label, self.membership[index]

In [6]:
priv_data = torch.load("out/data/priv.pt")
pub_data = torch.load("out/data/pub.pt")

  priv_data = torch.load("out/data/priv.pt")
  pub_data = torch.load("out/data/pub.pt")


## Simple Attacks
Collected from https://www.usenix.org/system/files/sec22fall_tang.pdf

In [7]:
class basic_attack_model(nn.Module):
    def __init__(self, layer_sizes, num_classes, output_size):
        super(basic_attack_model, self).__init__()
        self.fcs = nn.ModuleList()
        self.embedding = nn.Embedding(num_classes, layer_sizes[0])
        for i in range(len(layer_sizes)-1):
            self.fcs.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
                
    def forward(self, x, l):
        #print(x.shape, self.embedding(l).shape)
        input = x
        #print(input.shape)
        for i, fc in enumerate(self.fcs):
            if i == 0:
                input = torch.relu(fc(input + self.embedding(l)))
            elif i < len(self.fcs)-1:
                input = torch.relu(fc(input))
            else:
                input = torch.sigmoid(fc(input))
        return input

In [8]:
model = basic_attack_model([184, 8, 4, 1], 44, 1)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.to(device)

basic_attack_model(
  (fcs): ModuleList(
    (0): Linear(in_features=184, out_features=8, bias=True)
    (1): Linear(in_features=8, out_features=4, bias=True)
    (2): Linear(in_features=4, out_features=1, bias=True)
  )
  (embedding): Embedding(44, 184)
)

In [9]:
epochs = 100
batch_size = 256
Xtensor_list_pub, ytensor_list_pub = torch.load("out/data/basic_attack_pub_tensors.pt")
data_list_pub = [tensor[0, 0:184].view(1,184) for tensor in Xtensor_list_pub]
label_list_pub = [tensor[0, 184].item() for tensor in Xtensor_list_pub]
print(data_list_pub[0].shape, ytensor_list_pub[0].shape, len(Xtensor_list_pub))
for epoch in range(epochs):
    model.train()
    for i in range(0, len(Xtensor_list_pub), batch_size):
        databatch = torch.stack(data_list_pub[i:i+batch_size]).to(device)#.mean(dim = -1)
        labelbatch = torch.tensor([label_list_pub[i:i+batch_size]],dtype=torch.long).T.to(device)
        ybatch = torch.stack(ytensor_list_pub[i:i+batch_size]).to(device)
        optimizer.zero_grad()
        output = model(databatch, labelbatch)
        loss = criterion(output, ybatch.unsqueeze(1).type(torch.float32))
        loss.backward()
        optimizer.step()
    net_loss = 0
    net_correct = 0
    total = 0
    for i in range(0, len(Xtensor_list_pub), batch_size):
        databatch = torch.stack(data_list_pub[i:i+batch_size]).to(device)#.mean(dim = -1)
        labelbatch = torch.tensor([label_list_pub[i:i+batch_size]],dtype=torch.long).T.to(device)
        ybatch = torch.stack(ytensor_list_pub[i:i+batch_size]).to(device)
        optimizer.zero_grad()
        output = model(databatch, labelbatch)
        one_zero = (output > 0.5).type(torch.int32)
        correct = (one_zero == ybatch.unsqueeze(1)).sum().item()
        net_correct += correct
        total += len(ybatch)
        loss = criterion(output, ybatch.unsqueeze(1).type(torch.float32))
        net_loss += loss.item()
    print(epoch, net_loss * batch_size/len(Xtensor_list_pub), net_correct/total)

  Xtensor_list_pub, ytensor_list_pub = torch.load("out/data/basic_attack_pub_tensors.pt")


torch.Size([1, 184]) torch.Size([1]) 20000
0 0.7010613586425781 0.49825
1 0.7008734786987305 0.50215
2 0.7008628219604492 0.50275
3 0.700864353942871 0.50365
4 0.7008617736816406 0.5029
5 0.7008609283447266 0.504
6 0.700854483795166 0.50355
7 0.7008468467712402 0.50375
8 0.7008835494995117 0.50085
9 0.7008701141357422 0.50255
10 0.7008500961303711 0.5048
11 0.7008174995422364 0.50565
12 0.7008672866821289 0.5045
13 0.7008572608947754 0.50345
14 0.7008422225952149 0.50335
15 0.7008070198059082 0.5058
16 0.7008235160827637 0.50425
17 0.7008393608093262 0.504
18 0.7008278900146484 0.50485
19 0.7008167167663574 0.5064
20 0.7008029945373535 0.5064
21 0.7007360504150391 0.5063
22 0.7007863601684571 0.50755
23 0.7007619285583496 0.50825
24 0.7007449836730957 0.50725
25 0.7007748672485351 0.50455
26 0.7007443046569825 0.5085
27 0.7007644165039062 0.5031
28 0.7007228485107422 0.5072
29 0.7005765808105469 0.51175
30 0.700642000579834 0.50985
31 0.700639826965332 0.50845
32 0.7004924545288086 0.5

In [10]:
torch.save(model.state_dict(), "out/models/temp.pt")

In [11]:
ids = []
logit_score = []
loglogit_score = []
Xtensor_list_priv= torch.load("out/data/basic_attack_priv_tensors.pt")
data_list_priv = [tensor[0, 0:184].view(1,184) for tensor in Xtensor_list_priv]
label_list_priv = [tensor[0, 184].item() for tensor in Xtensor_list_priv]
model.eval()
for i in range(len(Xtensor_list_priv)):
    databatch = data_list_priv[i].unsqueeze(0).to(device)
    labelbatch = torch.tensor([label_list_priv[i]],dtype=torch.long).T.unsqueeze(0).to(device)
    output = model(databatch, labelbatch).to(device)
    logit = output.item()/(1-output.item())
    loglogit = np.log(logit)
    ids.append(priv_data[i][0])
    logit_score.append(logit)
    loglogit_score.append(loglogit)
    print(priv_data[i][0], output.item(), logit, loglogit)

  Xtensor_list_priv= torch.load("out/data/basic_attack_priv_tensors.pt")
  labelbatch = torch.tensor([label_list_priv[i]],dtype=torch.long).T.unsqueeze(0).to(device)


55061 0.5499773025512695 1.2221101417088558 0.20057898901903495
67669 0.5621360540390015 1.2838144341965805 0.2498356731697554
228925 0.5425598621368408 1.1860783897786091 0.17065239432603999
4248 0.4835769832134247 0.9363970378827536 -0.06571570667372331
109584 0.5549829006195068 1.2471046649490474 0.22082459657508516
117235 0.5518710613250732 1.2315006099737762 0.20823343388793597
71898 0.49150601029396057 0.9665915826814401 -0.033979227780557314
109497 0.5181673169136047 1.075409234580076 0.07270127246841475
117936 0.5308818221092224 1.1316590299189493 0.12368472408910225
116979 0.509926438331604 1.0405099932255502 0.03971097105872336
252165 0.5251113176345825 1.105756648103312 0.10052985005056495
44279 0.4704630374908447 0.8884423011032224 -0.11828557316608308
97390 0.5164665579795837 1.0681092828275918 0.06589006004811959
60210 0.5293799042701721 1.1248561399598136 0.11765515188832615
27707 0.4704630374908447 0.8884423011032224 -0.11828557316608308
214933 0.4704630374908447 0.8884

In [12]:
len(logit_score)

20000

In [13]:
df = pd.DataFrame(
    {
        "ids": ids,
        "score": logit_score,
    }
)
df.to_csv("test.csv", index=None)

In [14]:
response = requests.post("http://35.239.75.232:9090/mia", files={"file": open("test.csv", "rb")}, headers={"token": "13301858"})
print(response.json())

{'TPR@FPR=0.05': 0.043, 'AUC': 0.5112066111111111}
