In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision

from torchvision import datasets, transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Subset, random_split

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt


In [None]:
def fft2(x):
    """Forward 2-D discrete Fourier transform over the last two axes of ``x``."""
    return torch.fft.fftn(x, dim=(-2, -1))


def ifft2(x):
    """Inverse 2-D discrete Fourier transform over the last two axes of ``x``."""
    return torch.fft.ifftn(x, dim=(-2, -1))

def project_unit_magnitude(s, eps=1e-8):
    """Rescale every 2-D Fourier coefficient of ``s`` to (approximately) unit magnitude.

    Each coefficient is divided by its own magnitude plus ``eps``. ``eps`` keeps
    vanishing coefficients from causing a division by zero. Only the real part of
    the inverse transform is returned.
    """
    spectrum = fft2(s)
    unit_spectrum = spectrum / (spectrum.abs() + eps)
    return ifft2(unit_spectrum).real

def sample_secret_like(x):
    """Draw a random secret matching ``x`` in shape, dtype and device.

    Standard-normal noise is passed through ``project_unit_magnitude``, so the
    secret's 2-D spectrum has (approximately) unit magnitude everywhere.
    """
    return project_unit_magnitude(torch.randn_like(x))

def hrr_bind(x, s):
    """Bind ``x`` with secret ``s`` via 2-D circular convolution (HRR binding).

    Computes  x ⊛ s = F^{-1}(F(x) * F(s))  over the last two axes, where F is
    the 2-D DFT, and returns the real part. ``x`` and ``s`` must be
    broadcast-compatible for the element-wise product of their spectra.
    """
    # Element-wise product in the frequency domain == circular convolution in space.
    return ifft2(fft2(x) * fft2(s)).real


In [None]:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

full_train = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)

target_train, target_test, population_data = random_split(full_train, [20000, 20000, 10000])

train_loader = DataLoader(target_train, batch_size=64, shuffle=True)


100.0%


In [None]:
def make_cifar_resnet18():
    """Build a 10-class torchvision ResNet-18 with a small-image stem.

    The stem convolution is swapped for a 3x3, stride-1, padding-1 conv (no bias),
    and the stem max-pool is replaced by an identity. CIFAR-sized inputs are
    therefore not downsampled at the very first layers.
    """
    net = resnet18(num_classes=10)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    return net


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"


def train_hrr_model(loader, epochs=5, lr=1e-3):
    """Train a fresh CIFAR ResNet-18 on HRR-bound inputs.

    Each mini-batch is bound to a newly sampled secret (``sample_secret_like``)
    before the forward pass, so the network is only ever fed encoded images.

    Args:
        loader: iterable of ``(inputs, labels)`` mini-batches.
        epochs: number of full passes over ``loader``.
        lr: Adam learning rate.

    Returns:
        The trained model on ``device`` (still in train mode).
    """
    net = make_cifar_resnet18().to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    net.train()
    for epoch in range(1, epochs + 1):
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # A fresh secret per batch, shaped like the inputs.
            encoded = hrr_bind(inputs, sample_secret_like(inputs))

            optimizer.zero_grad()
            criterion(net(encoded), targets).backward()
            optimizer.step()

        print(f"epoch {epoch}/{epochs} done")
    return net


print("Training HRR target model...")
target_model = train_hrr_model(train_loader, epochs=5)


Training HRR target model...
epoch 1/5 done
epoch 2/5 done
epoch 3/5 done
epoch 4/5 done
epoch 5/5 done


In [None]:
class HRRWrapper(nn.Module):
    """Wrap a classifier so every forward pass sees an HRR-bound input.

    A new secret is drawn with ``sample_secret_like`` on each call. Repeated
    queries on the same input are therefore bound to different secrets.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model  # the wrapped classifier

    def forward(self, x):
        return self.base(hrr_bind(x, sample_secret_like(x)))

def _label_probs(model, imgs, labels):
    """Return (float64) the softmax probability ``model`` assigns to each sample's label.

    ``imgs`` carries a leading batch dimension; ``labels`` is a 1-D integer
    tensor with one entry per image.
    """
    probs = torch.softmax(model(imgs), dim=1)
    return probs.gather(1, labels.to(probs.device).view(-1, 1)).squeeze(1).double()


def _chunks(items, size):
    """Yield consecutive lists of at most ``size`` elements from the iterable ``items``."""
    chunk = []
    for item in items:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


def get_rmia_score_multi(tar_model, ref_models, known_img, known_label, population_subset,
                         gamma=1.0, a=0.3, batch_size=256):
    """RMIA-style membership score of one candidate sample against a target model.

    The score is the fraction of population pairs ``(z, z_label)`` for which

        (Pr(x | target) / Pr(x)) / (Pr(z | target) / Pr(z))  >  gamma

    Here Pr(. | model) is the softmax probability of the sample's own label.
    Pr(z) is the mean over the reference models. Pr(x) uses the linear
    approximation  0.5 * ((1 + a) * mean_ref Pr(x) + (1 - a)).

    Args:
        tar_model: target model, on the same device as ``known_img``.
        ref_models: non-empty list of reference models (same device).
        known_img: candidate image without a batch dimension, already on the
            models' device.
        known_label: integer class label of ``known_img`` (int or 0-d tensor).
        population_subset: non-empty sequence of ``(image, label)`` pairs. The
            images may live on any device; they are moved to ``known_img``'s.
        gamma: threshold on the pairwise likelihood ratio.
        a: coefficient of the linear approximation for Pr(x).
        batch_size: population samples evaluated per forward pass.

    Returns:
        Float in [0, 1]: the fraction of population samples exceeding ``gamma``.

    Raises:
        ValueError: if ``ref_models`` or ``population_subset`` is empty, or
            ``batch_size`` < 1.
    """
    if len(ref_models) == 0:
        raise ValueError("ref_models must contain at least one reference model")
    if len(population_subset) == 0:
        raise ValueError("population_subset must be non-empty")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    tar_model.eval()
    for rm in ref_models:
        rm.eval()

    # The caller places known_img on the models' device; population images are
    # moved there too (previously they stayed on CPU and crashed on CUDA).
    device = known_img.device

    with torch.no_grad():
        x_batch = known_img.unsqueeze(0)
        x_label = torch.tensor([int(known_label)], dtype=torch.long, device=device)

        prob_x_target = _label_probs(tar_model, x_batch, x_label).item()
        prob_x_out = torch.cat([_label_probs(rm, x_batch, x_label) for rm in ref_models]).mean().item()

        pr_x = 0.5 * ((1 + a) * prob_x_out + (1 - a))
        ratio_x = prob_x_target / (pr_x + 1e-10)

        count = 0
        # Score the population in batches instead of one forward pass per sample.
        for chunk in _chunks(population_subset, batch_size):
            z_imgs = torch.stack([torch.as_tensor(z_img) for z_img, _ in chunk]).to(device)
            z_labels = torch.tensor([int(z_label) for _, z_label in chunk],
                                    dtype=torch.long, device=device)

            prob_z_target = _label_probs(tar_model, z_imgs, z_labels)
            prob_z_out = torch.stack([_label_probs(rm, z_imgs, z_labels) for rm in ref_models]).mean(dim=0)
            ratio_z = prob_z_target / (prob_z_out + 1e-10)

            count += int(((ratio_x / (ratio_z + 1e-10)) > gamma).sum().item())

    return count / len(population_subset)


In [None]:
# Reference models for the attack: each is trained on a random half of
# population_data, which random_split made disjoint from target_train/target_test.
num_ref_models = 2
ref_models = []
ref_member_indices = []  # population indices each reference model was trained on

pop_indices = np.arange(len(population_data))

print("Training HRR reference models...")
for i in range(num_ref_models):
    np.random.shuffle(pop_indices)
    # Copy the slice: a bare numpy slice is a *view* into pop_indices, and the
    # in-place shuffle on the next iteration would silently rewrite the indices
    # this Subset (and ref_member_indices) refers to.
    member_idx = pop_indices[:len(pop_indices) // 2].copy()
    ref_member_indices.append(member_idx)

    subset = Subset(population_data, member_idx)
    loader = DataLoader(subset, batch_size=64, shuffle=True)

    m = train_hrr_model(loader, epochs=5)
    ref_models.append(m)

# Wrap target and references so every query is bound to a fresh random secret.
target_wrapped = HRRWrapper(target_model).to(device)
refs_wrapped = [HRRWrapper(m).to(device) for m in ref_models]


Training HRR reference models...
epoch 1/5 done
epoch 2/5 done
epoch 3/5 done
epoch 4/5 done
epoch 5/5 done
epoch 1/5 done
epoch 2/5 done
epoch 3/5 done
epoch 4/5 done
epoch 5/5 done


In [None]:
# The first 100 population samples serve as the z reference points for every query.
z_samples = [population_data[i] for i in range(100)]

scores, labels = [], []

for i in range(100):
    # Per index: score one member (label 1), then one non-member (label 0).
    for source, membership in ((target_train, 1), (target_test, 0)):
        img, lbl = source[i]
        scores.append(
            get_rmia_score_multi(target_wrapped, refs_wrapped, img.to(device), lbl, z_samples)
        )
        labels.append(membership)

fpr, tpr, _ = roc_curve(labels, scores)
roc_auc = auc(fpr, tpr)

print(f"HRR + RMIA AUROC: {roc_auc:.4f}")


HRR + RMIA AUROC: 0.5021


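The RMIA attack was evaluated against the HRR-protected model.
The attack achieved an AUROC of 0.5021, which is approximately equal to random guessing (0.5).
This indicates that the HRR defense prevents the attack from distinguishing between training members and non-members. Two caveats apply:

- The estimate is noisy. It uses only 100 members, 100 non-members, 100 population samples and 2 reference models.
- The target model's test accuracy is not reported. A near-chance AUROC is only evidence of privacy if the model still classifies HRR-bound inputs well; a model that failed to learn would also give an AUROC of about 0.5.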