In [None]:
import torch
import copy
from typing import Tuple
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
# Pick the first available accelerator: MPS, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Wipe everything left in /kaggle/working from earlier runs; the training cells
# below write their per-epoch checkpoint files into this directory.
!rm -rf /kaggle/working/*

In [None]:
class TaskDataset(Dataset):
    """Map-style dataset backed by parallel lists of sample ids, images and labels.

    The lists start empty and this class never fills them; presumably they are
    populated externally (e.g. the pickled split files loaded later with
    ``torch.load``) — TODO confirm.

    Args:
        transform: optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        """Return ``(id, img, label)`` at ``index``, applying ``transform`` to ``img`` if set."""
        id_ = self.ids[index]
        img = self.imgs[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label

    def __len__(self):
        """Number of samples, taken from the length of ``ids``."""
        return len(self.ids)

In [None]:
class MembershipDataset(TaskDataset):
    """TaskDataset that also reports a per-sample membership value.

    ``membership`` is indexed with the same index as ``ids``/``imgs``/``labels``,
    so it is presumably a parallel list. A ``None`` entry is reported as ``-1``.
    """

    def __init__(self, transform=None):
        super().__init__(transform)
        self.membership = []

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int, int]:
        """Return ``(id, img, label, membership)``, mapping a ``None`` membership to ``-1``."""
        id_, img, label = super().__getitem__(index)
        flag = self.membership[index]
        return id_, img, label, (-1 if flag is None else flag)

In [None]:
# Per-channel normalization statistics shared by every transform below.
# NOTE(review): hard-coded; presumably computed offline on this dataset — confirm source.
NORM_MEAN = [0.2980, 0.2962, 0.2987]
NORM_STD = [0.2886, 0.2875, 0.2889]


def _normalize_then(*extra):
    """Return ``Compose([Normalize(NORM_MEAN, NORM_STD), *extra])``.

    Each call builds its own ``Normalize`` instance, matching the original
    one-Compose-per-transform layout.
    """
    return transforms.Compose([transforms.Normalize(mean=NORM_MEAN, std=NORM_STD), *extra])


# Deterministic variants: no flip, horizontal flip, vertical flip, both flips.
transform_00 = _normalize_then()
transform_01 = _normalize_then(transforms.RandomHorizontalFlip(p=1))
transform_10 = _normalize_then(transforms.RandomVerticalFlip(p=1))
transform_11 = _normalize_then(transforms.RandomHorizontalFlip(p=1), transforms.RandomVerticalFlip(p=1))
# Random augmentation: each flip applied independently with probability 0.5.
transform_r = _normalize_then(transforms.RandomHorizontalFlip(p=0.5), transforms.RandomVerticalFlip(p=0.5))

In [None]:
# Optional warm start: set to a state-dict checkpoint path (e.g.
# '/kaggle/input/sprintmodel/attack_model.pt') to load it into both models;
# None keeps the random initialization.
CKPT_PATH = None
NUM_CLASSES = 44


def _build_model(num_classes=NUM_CLASSES, ckpt_path=None):
    """Return a ResNet-18 (``weights=None``) with a ``num_classes``-way head, on ``device``, in train mode.

    If ``ckpt_path`` is not None, the state dict stored there is loaded
    (mapped to ``device``) before returning.
    """
    model = resnet18(weights=None)
    # Size the replacement head from the backbone instead of hard-coding 512.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    if ckpt_path is not None:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    return model.to(device).train()


# Built in the same order as before (A fully, then B), so RNG consumption is unchanged.
model_A = _build_model(ckpt_path=CKPT_PATH)
model_B = _build_model(ckpt_path=CKPT_PATH)

In [None]:
data_num = 31  # index of the pre-generated A/B split pair to train on
SPLIT_DIR = '/kaggle/input/sprintml-dataset-1/pretrain_kaggle'


def _load_split(part):
    """Unpickle ``split_{data_num}_{part}.pt`` from ``SPLIT_DIR``.

    Presumably each file holds a MembershipDataset (the training code unpacks
    4-tuples from it) — TODO confirm. The class definitions above must be
    executed first so unpickling can resolve them.
    """
    # weights_only=False: these files pickle whole dataset objects, which the
    # weights-only loader (the default since PyTorch 2.6) rejects. This runs
    # arbitrary pickle code, so only load files from a trusted source.
    return torch.load(f'{SPLIT_DIR}/split_{data_num}_{part}.pt', weights_only=False)


dataset_A = _load_split('A')
dataset_B = _load_split('B')
dataloader_A = DataLoader(dataset_A, batch_size=128, shuffle=True)
dataloader_B = DataLoader(dataset_B, batch_size=128, shuffle=True)
# The evaluation code only reads these datasets, so reuse the loaded objects
# instead of unpickling each file a second time (avoids a duplicate in-memory copy).
dataset_AA = dataset_A
dataset_BB = dataset_B

In [None]:
# Each model gets its own Adam optimizer (lr=1e-3) and its own cross-entropy
# criterion object; nothing is shared between the A and B runs.
criterion_A = torch.nn.CrossEntropyLoss()
criterion_B = torch.nn.CrossEntropyLoss()
optimizer_A = torch.optim.Adam(params=model_A.parameters(), lr=1e-3)
optimizer_B = torch.optim.Adam(params=model_B.parameters(), lr=1e-3)

In [None]:
def _train_one_epoch(model, dataloader, optimizer, criterion):
    """Run one optimization pass over ``dataloader`` with random-flip augmentation (``transform_r``)."""
    model.train()
    for _ids, imgs, labels, _memberships in dataloader:
        # transform_r is applied per image, so each sample gets its own random flips.
        imgs = torch.stack([transform_r(img) for img in imgs]).to(device)
        labels = labels.to(device)
        loss = criterion(model(imgs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _evaluate_accuracy(model, dataset, batch_size):
    """Return top-1 accuracy of ``model`` on ``dataset`` using the deterministic ``transform_00``.

    Iterates over contiguous slices ``dataset[i:i + batch_size]`` that unpack
    into ``(ids, imgs, labels, memberships)``. NOTE(review): TaskDataset applies
    its ``transform`` to the sliced image *list*, so this presumably requires
    ``dataset.transform is None`` — confirm for the loaded splits.
    Returns ``nan`` for an empty dataset instead of dividing by zero.
    """
    correct = 0
    total = 0
    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for start in range(0, len(dataset), batch_size):
            _ids, imgs, labels, _memberships = dataset[start:start + batch_size]
            imgs = torch.stack([transform_00(img) for img in imgs]).to(device)
            labels = torch.as_tensor(labels).to(device)
            preds = model(imgs)
            correct += (preds.argmax(dim=1) == labels).sum().item()
            total += len(labels)
    return correct / total if total else float('nan')


def train_model(model, dataloader_1, dataset_1, dataset_2, optimizer, criterion, epochs, batch_size, path):
    """Train ``model`` for ``epochs`` epochs and save a per-epoch snapshot dict to ``path``.

    Args:
        model: network to optimize; it is left in eval mode on return.
        dataloader_1: training batches of ``(ids, imgs, labels, memberships)``.
        dataset_1: dataset whose accuracy is reported as "in" after each epoch.
        dataset_2: dataset whose accuracy is reported as "out" after each epoch.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to ``(model(imgs), labels)``.
        epochs: number of training epochs.
        batch_size: slice size used for evaluation only; the training batch size
            is whatever ``dataloader_1`` was built with.
        path: destination passed to ``torch.save``.

    The saved object maps ``epoch`` (1-based) to
    ``{'in_accuracy', 'out_accuracy', 'state_dict'}``.
    """
    save_dict = dict()
    for epoch in range(1, epochs + 1):
        _train_one_epoch(model, dataloader_1, optimizer, criterion)
        model.eval()
        acc_in = _evaluate_accuracy(model, dataset_1, batch_size)
        acc_out = _evaluate_accuracy(model, dataset_2, batch_size)
        print(f'Epoch {epoch}/{epochs}, accuracy in: {acc_in}, accuracy out: {acc_out}')
        # NOTE(review): snapshots stay on the model's device; many epochs of
        # ResNet-18 copies add up (roughly 45 MB each) — consider moving to CPU.
        save_dict[epoch] = {'in_accuracy' : acc_in, 'out_accuracy' : acc_out, 'state_dict' : copy.deepcopy(model.state_dict())}
    torch.save(save_dict, path)

In [None]:
# Model A trains on split A ("in") and is also evaluated on split B ("out");
# model B mirrors that with the roles swapped.
train_model(
    model=model_A,
    dataloader_1=dataloader_A,
    dataset_1=dataset_AA,
    dataset_2=dataset_BB,
    optimizer=optimizer_A,
    criterion=criterion_A,
    epochs=40,
    batch_size=256,
    path=f'/kaggle/working/split_{data_num}_A_output.pt',
)
train_model(
    model=model_B,
    dataloader_1=dataloader_B,
    dataset_1=dataset_BB,
    dataset_2=dataset_AA,
    optimizer=optimizer_B,
    criterion=criterion_B,
    epochs=40,
    batch_size=256,
    path=f'/kaggle/working/split_{data_num}_B_output.pt',
)