In [1]:
import torch
import pandas as pd
import numpy as np
from typing import Tuple
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
device = torch.device('cpu' if torch.backends.mps.is_available() else 'cpu')

In [2]:
class TaskDataset(Dataset):
    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if not self.transform is None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label
    def __len__(self):
        return len(self.ids)

In [3]:
class MembershipDataset(TaskDataset):
    def __init__(self, transform=None):
        super().__init__(transform)
        self.membership = []
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int, int]:
        id_, img, label = super().__getitem__(index)
        return id_, img, label, self.membership[index]

In [4]:
transform_00 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
])
transform_01 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomHorizontalFlip(p=1),  # Apply horizontal flip
])
transform_10 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomVerticalFlip(p=1),    # Apply vertical flip
])
transform_11 = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomHorizontalFlip(p=1),  # Apply horizontal flip
    transforms.RandomVerticalFlip(p=1),    # Apply vertical flip
])
transform_r = transforms.Compose([
    transforms.Normalize(mean=[0.2980, 0.2962, 0.2987], std=[0.2886, 0.2875, 0.2889]),  # Normalize with mean and std
    transforms.RandomHorizontalFlip(p=0.5),  # Apply horizontal flip
    transforms.RandomVerticalFlip(p=0.5),    # Apply vertical flip
])

In [5]:
priv_dataset = torch.load('out/data/priv.pt')
pub_dataset = torch.load('out/data/pub.pt')

  priv_dataset = torch.load('out/data/priv.pt')
  pub_dataset = torch.load('out/data/pub.pt')


In [6]:
priv_loader = DataLoader(priv_dataset, batch_size=16, shuffle=True)
pub_loader = DataLoader(pub_dataset, batch_size=16, shuffle=True)

In [7]:
priv_membership = dict()
for i in range(1,33):
    dataset_A = torch.load(f"pretrain_kaggle/split_{i}_A.pt")
    dataset_B = torch.load(f"pretrain_kaggle/split_{i}_B.pt")
    membership_i = []
    for id_ in priv_dataset.ids:
        if (i == 1):
            priv_membership[id_] = dict()
        if id_ in dataset_A.ids:
            priv_membership[id_][i] = 'A'
        elif id_ in dataset_B.ids:
            priv_membership[id_][i] = 'B'
pub_membership = dict()
for i in range(1,33):
    dataset_A = torch.load(f"pretrain_kaggle/split_{i}_A.pt")
    dataset_B = torch.load(f"pretrain_kaggle/split_{i}_B.pt")
    membership_i = []
    for id_ in pub_dataset.ids:
        if (i == 1):
            pub_membership[id_] = dict()
        if id_ in dataset_A.ids:
            pub_membership[id_][i] = 'A'
        elif id_ in dataset_B.ids:
            pub_membership[id_][i] = 'B'

  dataset_A = torch.load(f"pretrain_kaggle/split_{i}_A.pt")
  dataset_B = torch.load(f"pretrain_kaggle/split_{i}_B.pt")
  dataset_A = torch.load(f"pretrain_kaggle/split_{i}_A.pt")
  dataset_B = torch.load(f"pretrain_kaggle/split_{i}_B.pt")


In [8]:
model_epoch = dict()
model_epoch["1_A"] = 22
model_epoch["1_B"] = 26
model_epoch["2_A"] = 25
model_epoch["2_B"] = 26
model_epoch["3_A"] = 22
model_epoch["3_B"] = 26
model_epoch["4_A"] = 21
model_epoch["4_B"] = 25
model_epoch["5_A"] = 27
model_epoch["5_B"] = 25
model_epoch["6_A"] = 26
model_epoch["6_B"] = 25
model_epoch["7_A"] = 26
model_epoch["7_B"] = 28
model_epoch["8_A"] = 27
model_epoch["8_B"] = 21
model_epoch["9_A"] = 26
model_epoch["9_B"] = 24
model_epoch["10_A"] = 27
model_epoch["10_B"] = 25
model_epoch["11_A"] = 24
model_epoch["11_B"] = 25
model_epoch["12_A"] = 25
model_epoch["12_B"] = 23
model_epoch["13_A"] = 25
model_epoch["13_B"] = 23
model_epoch["14_A"] = 27
model_epoch["14_B"] = 22
model_epoch["15_A"] = 17
model_epoch["15_B"] = 22
model_epoch["16_A"] = 28
model_epoch["16_B"] = 25
model_epoch["17_A"] = 26
model_epoch["17_B"] = 24
model_epoch["18_A"] = 26
model_epoch["18_B"] = 27
model_epoch["19_A"] = 25
model_epoch["19_B"] = 20
model_epoch["20_A"] = 23
model_epoch["20_B"] = 26
model_epoch["21_A"] = 26
model_epoch["21_B"] = 25
model_epoch["22_A"] = 22
model_epoch["22_B"] = 29
model_epoch["23_A"] = 25
model_epoch["23_B"] = 25
model_epoch["24_A"] = 25
model_epoch["24_B"] = 25
model_epoch["25_A"] = 25
model_epoch["25_B"] = 23
model_epoch["26_A"] = 22
model_epoch["26_B"] = 29
model_epoch["27_A"] = 26
model_epoch["27_B"] = 25
model_epoch["28_A"] = 22
model_epoch["28_B"] = 20
model_epoch["29_A"] = 24
model_epoch["29_B"] = 26
model_epoch["30_A"] = 25
model_epoch["30_B"] = 24
model_epoch["31_A"] = 24
model_epoch["31_B"] = 26
model_epoch["32_A"] = 25
model_epoch["32_B"] = 20

In [9]:
model_map = dict()
for i in range(1,33):
    model_A = resnet18()
    model_A.fc = torch.nn.Linear(512, 44)
    A_data = torch.load(f"splits/split_{i}_A_output.pt", map_location=device)
    ckpt_A = A_data[model_epoch[f"{i}_A"]]["state_dict"]
    model_A.load_state_dict(ckpt_A)
    model_A.eval()
    model_A.to(device)
    model_map[f"{i}_A"] = model_A
    model_B = resnet18()
    model_B.fc = torch.nn.Linear(512, 44)
    B_data = torch.load(f"splits/split_{i}_B_output.pt", map_location=device)
    ckpt_B = B_data[model_epoch[f"{i}_B"]]["state_dict"]
    model_B.load_state_dict(ckpt_B)
    model_B.eval()
    model_B.to(device)
    model_map[f"{i}_B"] = model_B

  A_data = torch.load(f"splits/split_{i}_A_output.pt", map_location=device)
  B_data = torch.load(f"splits/split_{i}_B_output.pt", map_location=device)


In [10]:
priv_in_scores = dict()
priv_out_scores = dict()
for key in model_epoch.keys():
    model = model_map[key]
    processed = 0
    for id_, img, label in zip(priv_dataset.ids, priv_dataset.imgs, priv_dataset.labels):
        img_00 = transform_00(img).unsqueeze(0)
        img_01 = transform_01(img).unsqueeze(0)
        img_10 = transform_10(img).unsqueeze(0)
        img_11 = transform_11(img).unsqueeze(0)
        imgs_ = torch.cat((img_00, img_01, img_10, img_11), dim=0)
        outputs = model(imgs_)
        raw_scores = torch.nn.functional.softmax(outputs, dim=1)
        #print(raw_scores.shape)
        confidence_score = raw_scores[:, label]
        #print(confidence_score.shape)
        logit_score = torch.log(confidence_score/(1-confidence_score))
        #print(logit_score.shape)
        #print(logit_score)
        if priv_membership[id_][int(key[:-2])]==key[-1]:
            if id_ not in priv_in_scores:
                priv_in_scores[id_] = list()
            priv_in_scores[id_].extend(logit_score.detach().cpu().numpy())
        else:
            if id_ not in priv_out_scores:
                priv_out_scores[id_] = list()
            priv_out_scores[id_].extend(logit_score.detach().cpu().numpy())
        processed += 1
        if processed % 5000 == 0:
            print(f"Processed {processed} images for model {key}")
    print(f"Processed model {key}")

Processed 5000 images for model 1_A
Processed 10000 images for model 1_A
Processed 15000 images for model 1_A
Processed 20000 images for model 1_A
Processed model 1_A
Processed 5000 images for model 1_B
Processed 10000 images for model 1_B
Processed 15000 images for model 1_B
Processed 20000 images for model 1_B
Processed model 1_B
Processed 5000 images for model 2_A
Processed 10000 images for model 2_A
Processed 15000 images for model 2_A
Processed 20000 images for model 2_A
Processed model 2_A
Processed 5000 images for model 2_B
Processed 10000 images for model 2_B
Processed 15000 images for model 2_B
Processed 20000 images for model 2_B
Processed model 2_B
Processed 5000 images for model 3_A
Processed 10000 images for model 3_A
Processed 15000 images for model 3_A
Processed 20000 images for model 3_A
Processed model 3_A
Processed 5000 images for model 3_B
Processed 10000 images for model 3_B
Processed 15000 images for model 3_B
Processed 20000 images for model 3_B
Processed model 3_

In [11]:
torch.save(priv_in_scores, "splits/in_scores.pt")
torch.save(priv_out_scores, "splits/out_scores.pt")