In [1]:
import random
import torch
import pandas as pd
import numpy as np
from typing import Tuple
from torch.utils.data import Dataset, Subset

In [2]:
class TaskDataset(Dataset):
    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        id_ = self.ids[index]
        img = self.imgs[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label
    def __len__(self):
        return len(self.ids)

In [3]:
class MembershipDataset(TaskDataset):
    def __init__(self, transform=None):
        super().__init__(transform)
        self.membership = []
    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int, int]:
        id_, img, label = super().__getitem__(index)
        return id_, img, label, self.membership[index]

In [4]:
def create_class_balanced_splits(dataset, num_classes, split_size, alternate_odd=True):
    class_indices = {i: [] for i in range(num_classes)}
    for idx in range(len(dataset)):
        _, _, label, _ = dataset[idx]
        class_indices[label].append(idx)
    indices_A, indices_B = [], []
    odd_flag = True
    for label, indices in class_indices.items():
        random.shuffle(indices)
        mid = len(indices) // 2
        if len(indices) % 2 == 1 and alternate_odd:
            if odd_flag:
                indices_A.extend(indices[:mid + 1])
                indices_B.extend(indices[mid + 1:])
            else:
                indices_A.extend(indices[:mid])
                indices_B.extend(indices[mid:])
            odd_flag = not odd_flag
        else:
            indices_A.extend(indices[:mid])
            indices_B.extend(indices[mid:])
    random.shuffle(indices_A)
    random.shuffle(indices_B)
    return Subset(dataset, indices_A[:split_size]), Subset(dataset, indices_B[:split_size])

In [5]:
def update_dataframe(dataset, dataframe, splits, num):
    dataframe[f"split_{num}"] = ""
    membership = np.array(["" for _ in range(len(dataset))], dtype=object)
    membership[splits[0]] = "A"
    membership[splits[1]] = "B"
    dataframe[f"split_{num}"] = membership
    return dataframe

In [6]:
def process_and_save(dataset, num_classes=44, num_splits=8, output_path=".", name="priv"):
    data = {
        "id": [dataset.ids[i] for i in range(len(dataset))],
        "label": [dataset.labels[i] for i in range(len(dataset))]
    }
    dataframe = pd.DataFrame(data)
    for i in range(1, num_splits + 1):
        print(f"Split {i} for {name} started")
        dataset_A, dataset_B = create_class_balanced_splits(dataset, num_classes, len(dataset) // 2)
        torch.save(dataset_A, f"{output_path}/split_{i}_A_{name}.pt")
        torch.save(dataset_B, f"{output_path}/split_{i}_B_{name}.pt")
        dataframe = update_dataframe(dataset, dataframe, (dataset_A.indices, dataset_B.indices), i)
    dataframe.to_csv(f"{output_path}/membership_splits_{name}.csv", index=False)

In [7]:
pub_dataset = torch.load("out/data/pub.pt")
process_and_save(pub_dataset, num_classes=44, num_splits=32, output_path="./outputs", name="pub")

  pub_dataset = torch.load("out/data/pub.pt")


Split 1 for pub started
Split 2 for pub started
Split 3 for pub started
Split 4 for pub started
Split 5 for pub started
Split 6 for pub started
Split 7 for pub started
Split 8 for pub started
Split 9 for pub started
Split 10 for pub started
Split 11 for pub started
Split 12 for pub started
Split 13 for pub started
Split 14 for pub started
Split 15 for pub started
Split 16 for pub started
Split 17 for pub started
Split 18 for pub started
Split 19 for pub started
Split 20 for pub started
Split 21 for pub started
Split 22 for pub started
Split 23 for pub started
Split 24 for pub started
Split 25 for pub started
Split 26 for pub started
Split 27 for pub started
Split 28 for pub started
Split 29 for pub started
Split 30 for pub started
Split 31 for pub started
Split 32 for pub started


In [8]:
priv_dataset = torch.load("out/data/priv.pt")
process_and_save(priv_dataset, num_classes=44, num_splits=32, output_path="./outputs", name="priv")

  priv_dataset = torch.load("out/data/priv.pt")


Split 1 for priv started
Split 2 for priv started
Split 3 for priv started
Split 4 for priv started
Split 5 for priv started
Split 6 for priv started
Split 7 for priv started
Split 8 for priv started
Split 9 for priv started
Split 10 for priv started
Split 11 for priv started
Split 12 for priv started
Split 13 for priv started
Split 14 for priv started
Split 15 for priv started
Split 16 for priv started
Split 17 for priv started
Split 18 for priv started
Split 19 for priv started
Split 20 for priv started
Split 21 for priv started
Split 22 for priv started
Split 23 for priv started
Split 24 for priv started
Split 25 for priv started
Split 26 for priv started
Split 27 for priv started
Split 28 for priv started
Split 29 for priv started
Split 30 for priv started
Split 31 for priv started
Split 32 for priv started


In [9]:
def dataset_creation(pub_dataset, priv_dataset, num):
    priv_A = torch.load(f'outputs/split_{num}_A_priv.pt')
    priv_B = torch.load(f'outputs/split_{num}_B_priv.pt')
    pub_A = torch.load(f'outputs/split_{num}_A_pub.pt')
    pub_B = torch.load(f'outputs/split_{num}_B_pub.pt')
    dataset_A = MembershipDataset()
    dataset_B = MembershipDataset()
    for index in priv_A.indices:
        dataset_A.ids.append(priv_dataset.ids[index])
        dataset_A.imgs.append(priv_dataset.imgs[index])
        dataset_A.labels.append(priv_dataset.labels[index])
        dataset_A.membership.append(priv_dataset.membership[index])
    for index in pub_A.indices:
        dataset_A.ids.append(pub_dataset.ids[index])
        dataset_A.imgs.append(pub_dataset.imgs[index])
        dataset_A.labels.append(pub_dataset.labels[index])
        dataset_A.membership.append(pub_dataset.membership[index])
    for index in priv_B.indices:
        dataset_B.ids.append(priv_dataset.ids[index])
        dataset_B.imgs.append(priv_dataset.imgs[index])
        dataset_B.labels.append(priv_dataset.labels[index])
        dataset_B.membership.append(priv_dataset.membership[index])
    for index in pub_B.indices:
        dataset_B.ids.append(pub_dataset.ids[index])
        dataset_B.imgs.append(pub_dataset.imgs[index])
        dataset_B.labels.append(pub_dataset.labels[index])
        dataset_B.membership.append(pub_dataset.membership[index])
    torch.save(dataset_A, f'pretrain/split_{num}_A.pt')
    torch.save(dataset_B, f'pretrain/split_{num}_B.pt')

In [10]:
def all_datasets(pub_dataset, priv_dataset, num_splits):
    for i in range(1, num_splits + 1):
        dataset_creation(pub_dataset, priv_dataset, i)
        print(f"Split {i} completed")

In [11]:
all_datasets(pub_dataset, priv_dataset, num_splits=32)

  priv_A = torch.load(f'outputs/split_{num}_A_priv.pt')
  priv_B = torch.load(f'outputs/split_{num}_B_priv.pt')
  pub_A = torch.load(f'outputs/split_{num}_A_pub.pt')
  pub_B = torch.load(f'outputs/split_{num}_B_pub.pt')


Split 1 completed
Split 2 completed
Split 3 completed
Split 4 completed
Split 5 completed
Split 6 completed
Split 7 completed
Split 8 completed
Split 9 completed
Split 10 completed
Split 11 completed
Split 12 completed
Split 13 completed
Split 14 completed
Split 15 completed
Split 16 completed
Split 17 completed
Split 18 completed
Split 19 completed
Split 20 completed
Split 21 completed
Split 22 completed
Split 23 completed
Split 24 completed
Split 25 completed
Split 26 completed
Split 27 completed
Split 28 completed
Split 29 completed
Split 30 completed
Split 31 completed
Split 32 completed
