In [1]:
import torch
import torchvision.transforms as transforms
from typing import Tuple
from torch.utils.data import Dataset
from torchvision.models import resnet18

In [2]:
main_model = resnet18(pretrained=False)
main_model.fc = torch.nn.Linear(512, 44)
ckpt = torch.load("out/models/attack_model.pt", map_location="cpu")
main_model.load_state_dict(ckpt)
main_model.eval()

  ckpt = torch.load("out/models/attack_model.pt", map_location="cpu")


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
transform_00 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
])
transform_01 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomHorizontalFlip(p=1),  # Apply horizontal flip
])
transform_10 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomVerticalFlip(p=1),    # Apply vertical flip
])
transform_11 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomHorizontalFlip(p=1),  # Apply horizontal flip
    transforms.RandomVerticalFlip(p=1),    # Apply vertical flip
])

In [4]:
class TaskDataset(Dataset):
    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if not self.transform is None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label
    def __len__(self):
        return len(self.ids)

In [5]:
class MembershipDataset(TaskDataset):
    def __init__(self, transform=None):
        super().__init__(transform)
        self.membership = []
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int, int]:
        id_, img, label = super().__getitem__(index)
        return id_, img, label, self.membership[index]

In [6]:
priv_data = torch.load("out/data/priv.pt")

  priv_data = torch.load("out/data/priv.pt")


In [7]:
def output_to_tensor(output, label):
    output = output.unsqueeze(0)
    confidence = torch.nn.functional.softmax(output, dim=1)[0, label]
    #print(confidence.shape)
    entropy = (-torch.nn.functional.softmax(output, dim=1) * torch.nn.functional.log_softmax(output, dim=1)).sum()
    #mixed_entropy = -(torch.nn.functional.softmax(output, dim=1) * torch.nn.functional.log_softmax(torch.ones_like(output) - output, dim=1)).sum().item() + confidence*torch.log(1-confidence) - (1-confidence)*torch.log(confidence)
    tensor = torch.cat((output, confidence.view(1,1), entropy.view(1,1)), dim=1)
    return tensor

In [8]:
Xtensor_list_priv = []
for i in range(len(priv_data)):
    id_, img, label, membership = priv_data[i]
    img_00 = transform_00(img).unsqueeze(0)
    img_01 = transform_01(img).unsqueeze(0)
    img_10 = transform_10(img).unsqueeze(0)
    img_11 = transform_11(img).unsqueeze(0)
    imgs = torch.cat((img_00, img_01, img_10, img_11), dim=0)
    outputs = main_model(imgs)
    final_tensor = torch.cat([output_to_tensor(output, label) for output in outputs] + [torch.tensor([label], dtype=torch.long).view(1,1)], dim=1)
    Xtensor_list_priv.append(final_tensor)
    if (i+1)%1000 == 0:
        print(f"Processed {i+1} points")

Processed 1000 points
Processed 2000 points
Processed 3000 points
Processed 4000 points
Processed 5000 points
Processed 6000 points
Processed 7000 points
Processed 8000 points
Processed 9000 points
Processed 10000 points
Processed 11000 points
Processed 12000 points
Processed 13000 points
Processed 14000 points
Processed 15000 points
Processed 16000 points
Processed 17000 points
Processed 18000 points
Processed 19000 points
Processed 20000 points


In [9]:
torch.save(Xtensor_list_priv, "out/data/basic_attack_priv_tensors.pt")