In [None]:
import os
import argparse
import itertools
import numpy as np
from tqdm import tqdm

import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.utils.data import TensorDataset

from utils import *
from dataset_loader import *

from unet import Unet
from diffusion import GaussianDiffusion
from embedding import ConditionalEmbedding
from scheduler import GradualWarmupScheduler
from mine.mine import Mine, T

from tensorflow.keras.models import load_model
from sklearn.metrics import mean_squared_error

In [None]:
# Run index: appended to the model directory name and to the saved classifier file names.
trial = 1

In [None]:
class Args:
    """Experiment-level settings (notebook stand-in for the commented-out `args_parser()`)."""
    def __init__(self):
        self.gpu_id = 0 # GPU index exported through CUDA_VISIBLE_DEVICES when >= 0
        self.batch_size = 128
        self.num_public_attr = 4 # number of public attribute classes (i.e., activity)
        self.num_private_attr = 2 # number of private attribute classes (e.g., gender)
        self.dataset = 'motion' # 'mobi', 'motion', 'wifi'
        self.private = 'gender' # 'weight', 'gender', 'height'
        self.epochs = 90 # NOTE(review): not read by any cell visible here; training uses per-model epoch counts
        self.train_surrogate = False # NOTE(review): not read by any cell visible here (see `train_pub`)
        self.seed = 102 # NOTE(review): no visible cell seeds torch/numpy with this value; confirm it is applied
        self.verbose = 1 # for Keras-based evaluation models

args = Args()
# args = args_parser()

# Dataset-specific adjustments: number of private classes and, for WiFi, a smaller batch.
if args.dataset == 'mobi':
    # Any other private attribute leaves the default class count untouched.
    _mobi_private_classes = {'weight': 3, 'gender': 2}
    if args.private in _mobi_private_classes:
        args.num_private_attr = _mobi_private_classes[args.private]
elif args.dataset in ('motion', 'wifi'):
    if args.dataset == 'wifi':
        args.batch_size = 8
    args.num_private_attr = 2

In [None]:
# Select the PyTorch device.
# CUDA_VISIBLE_DEVICES must be exported before the first CUDA query so that the
# requested physical GPU becomes logical device 0 ("cuda:0") inside this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if args.gpu_id >= 0:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

# Always defined: previously this name only existed when gpu_id >= 0, so a negative
# gpu_id (i.e. "run on CPU") crashed with a NameError on the next statement.
cuda_id = "cuda:" + str(0)
use_cuda = args.gpu_id >= 0 and torch.cuda.is_available()

device = torch.device(cuda_id if use_cuda else "cpu")
if use_cuda:
    torch.cuda.set_device(cuda_id)
    print("Current GPU ID:", torch.cuda.current_device())

In [None]:
class DiffusionArgs:
    """Hyperparameters for the Unet/diffusion model, its training loop and DDIM sampling.

    Reads the notebook globals `args` and `trial` to derive `moddir` and `gpu_id`.
    """
    def __init__(self):
        self.inch = 1 # input channels for Unet model
        self.modch = 256 # model channels for Unet model
        self.T = 1000  # number of diffusion timesteps (passed to get_named_beta_schedule)
        self.outch = 1 # output channels for Unet model
        self.chmul = [1,2] # architecture parameter passed to Unet as ch_mul (alternative: [1,2,2,2])
        self.numres = 2 # number of resblocks for each block in Unet model
        self.cdim = 60 # dimension of conditional embedding
        self.useconv = True # whether to use convolution in downsample
        self.droprate = 0.1 # dropout rate for model
        self.dtype = torch.float32
        self.lr = 2e-4 # learning rate
        self.w1 = 1.8 # hyperparameters for classifier-free guidance strength
        self.v = 0.1 # hyperparameters for the variance of posterior distribution
        self.epoch = 80 # epochs for training
        self.multiplier = 2 # multiplier for warmup
        self.threshold = 0.3 # per-sample probability of zeroing the conditional embedding during training
        self.interval = 4 # epoch interval between two saved checkpoints
        self.moddir = './models/privdiffuser_'+args.dataset+'_'+str(trial) # model addresses
        # self.samdir = 'privdiffuser_sample' # sample addresses
        self.genbatch = 80 # batch size for sampling process
        # self.clsnum = 4 # 10 # num of label classes
        self.num_steps = 50 # sampling steps for DDIM
        self.eta = 0 # eta for variance during DDIM sampling process
        self.select = 'linear' # selection strategy for DDIM
        self.ddim = True # whether to use ddim
        self.local_rank = -1 # node rank for distributed training
        self.gpu_id = args.gpu_id
        self.w2 = 0.0 # hyperparameters for negative classifier guidance strength
        
diff_args = DiffusionArgs()

# Dataset-specific overrides of the DiffusionArgs defaults.
if args.dataset == 'wifi':
    diff_args.useconv = False
    diff_args.epoch = 28
elif args.dataset == 'mobi':
    diff_args.useconv = True
    diff_args.epoch = 28
    diff_args.modch=64
elif args.dataset == 'motion':
    diff_args.epoch = 80 
    diff_args.interval = 8

# Create the directory that receives checkpoints, classifier weights and saved arrays.
if not os.path.exists(diff_args.moddir):
    os.makedirs(diff_args.moddir)
    print("Mod dir created.")
else:
    print("Mod dir already exist.")

In [None]:
# load dataset
# Each loader returns train/test inputs plus the label arrays available for that dataset:
# activity (public) for all three, gender+weight for 'mobi', gender for 'motion',
# weight+height for 'wifi'. The loaders come from the star-imported project modules.
if args.dataset == 'mobi':
    x_train, x_test, activity_train_label, activity_test_label, gender_train_label, gender_test_label, weight_train_label, weight_test_label, user_groups, user_groups_test, id_train, id_test = load_mobiact(args)
elif args.dataset == 'motion':
    x_train, x_test, activity_train_label, activity_test_label, gender_train_label, gender_test_label, user_groups, user_groups_test, id_train, id_test = load_motionsense()
elif args.dataset == 'wifi':
    x_train, x_test, activity_train_label, activity_test_label, weight_train_label, weight_test_label, height_train_label, height_test_label, user_groups, user_groups_test, id_train, id_test = load_wifi(args)

In [None]:
# prepare dataset
# Inputs are cast to float32 and permuted with (0,3,1,2), i.e. the last axis is moved to position 1.
x_train_tensor = torch.from_numpy(x_train.astype('float32'))
x_test_tensor = torch.from_numpy(x_test.astype('float32'))

x_train_tensor = torch.permute(x_train_tensor, (0,3,1,2))
x_test_tensor = torch.permute(x_test_tensor, (0,3,1,2))

# Labels are reduced to class indices with argmax over axis 1
# (assumes the label arrays are one-hot encoded; TODO confirm against the loaders).
act_train_tensor = torch.from_numpy(np.argmax(activity_train_label, axis=1))
act_test_tensor = torch.from_numpy(np.argmax(activity_test_label, axis=1))

if args.dataset == 'mobi':
    gen_train_tensor = torch.from_numpy(np.argmax(gender_train_label, axis=1))
    weight_train_tensor = torch.from_numpy(np.argmax(weight_train_label, axis=1))
    train_dataset = TensorDataset(x_train_tensor, act_train_tensor, gen_train_tensor, weight_train_tensor)
elif args.dataset == 'wifi':
    height_train_tensor = torch.from_numpy(np.argmax(height_train_label, axis=1))
    weight_train_tensor = torch.from_numpy(np.argmax(weight_train_label, axis=1))
    train_dataset = TensorDataset(x_train_tensor, act_train_tensor, height_train_tensor, weight_train_tensor)
elif args.dataset == 'motion':
    gen_train_tensor = torch.from_numpy(np.argmax(gender_train_label, axis=1))
    train_dataset = TensorDataset(x_train_tensor, act_train_tensor, gen_train_tensor)
    
# list(...) iterates the DataLoader once, so the batches (and their shuffled order) are
# materialised here and reused unchanged by every later pass over this list.
data_train_loader = list(torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True))

In [None]:
def _as_nchw(array):
    """Cast a NumPy batch to float32 and permute it with (0, 3, 1, 2)."""
    return torch.permute(torch.from_numpy(array.astype('float32')), (0, 3, 1, 2))


def _materialise(dataset):
    """Iterate a shuffled, drop-last DataLoader once and keep the batches in a list."""
    return list(torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True))


if args.dataset not in ('mobi', 'wifi', 'motion'):
    raise ValueError('Dataset not found:', args.dataset)

# Surrogate (public attribute) training set: inputs with raw activity labels.
surrogate_train_dataset = TensorDataset(_as_nchw(x_train), torch.from_numpy(activity_train_label))

# Per dataset: the extra test label arrays (in batch order) and the private attributes
# that can be selected through args.private.
if args.dataset == 'mobi':
    _extra_test_labels = (gender_test_label, weight_test_label)
    _private_train_labels = {'gender': gender_train_label, 'weight': weight_train_label}
elif args.dataset == 'wifi':
    _extra_test_labels = (height_test_label, weight_test_label)
    _private_train_labels = {'height': height_train_label, 'weight': weight_train_label}
else:  # 'motion'
    _extra_test_labels = (gender_test_label,)
    _private_train_labels = {'gender': gender_train_label}

# Test set shared by all evaluations: inputs, activity labels, then the extra labels.
surrogate_test_dataset = TensorDataset(
    _as_nchw(x_test),
    torch.from_numpy(activity_test_label),
    *[torch.from_numpy(extra) for extra in _extra_test_labels]
)

if args.private not in _private_train_labels:
    raise ValueError('Private attribute not found:', args.private, 'in dataset:', args.dataset)

# Private classifier training set: inputs, private labels, activity labels.
priv_train_dataset = TensorDataset(
    _as_nchw(x_train),
    torch.from_numpy(_private_train_labels[args.private]),
    torch.from_numpy(activity_train_label)
)

surrogate_train_loader = _materialise(surrogate_train_dataset)
surrogate_test_loader = _materialise(surrogate_test_dataset)
priv_train_loader = _materialise(priv_train_dataset)

## Surrogate Utility Model

In [None]:
class Surrogate_Utility_Classifier(nn.Module):
    """Two convolutions followed by an MLP head that classifies the public attribute.

    The convolution widths and the flattened feature size are picked from the notebook
    global `args.dataset`. `forward` returns the class logits together with the
    normalised `z_dim`-dimensional embedding they were computed from.
    """

    # dataset name -> (fc1 out channels, fc2 out channels, flattened size fed to fc3)
    _CONV_LAYOUT = {
        'mobi': (8, 16, 16640),
        'wifi': (16, 32, 241408),  # b8
        'motion': (64, 128, 66560),  # 64 128
    }

    def __init__(self, num_classes, z_dim):
        super().__init__()

        self.flatten = nn.Flatten()

        # Unknown dataset names leave fc1-fc3 undefined, exactly like an unmatched if/elif.
        layout = self._CONV_LAYOUT.get(args.dataset)
        if layout is not None:
            width1, width2, flat_features = layout
            self.fc1 = nn.Conv2d(1, width1, kernel_size=2, stride=1, padding=1)
            self.fc2 = nn.Conv2d(width1, width2, kernel_size=2, stride=1, padding=1)
            self.fc3 = nn.Linear(flat_features, 512)

        self.fc4 = nn.Linear(512, 128)
        self.fc5 = nn.Linear(128, z_dim)
        self.fc6 = nn.Linear(z_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return `(logits, z)` for an input batch `x`."""
        act = self.relu
        conv_out = act(self.fc2(act(self.fc1(x))))
        hidden = act(self.fc3(self.flatten(conv_out)))
        hidden = act(self.fc4(hidden))
        hidden = act(self.fc5(hidden))
        # The embedding is normalised along its last dimension before the final layer.
        z = torch.nn.functional.normalize(hidden, dim=-1)
        return self.fc6(z), z

In [None]:
def train_pub_classifier(model, trainloader, optimizer, epochs, model_path, device):
    """Train the surrogate utility classifier and save its state dict.

    Args:
        model: module called as `model(inputs)` and returning `(logits, feature)`.
        trainloader: iterable of `(inputs, labels)` batches; labels are compared with the
            softmaxed logits through BCELoss.
        optimizer: optimizer over the model parameters.
        epochs: number of passes over `trainloader`.
        model_path: destination of `torch.save(model.state_dict(), ...)`.
        device: torch device the batches are moved to.
    """
    print(model_path)
    model.train()
    bce = torch.nn.BCELoss()
    for epoch_idx in range(epochs):
        loss_window = 0.0
        for step, (inputs, labels) in enumerate(trainloader, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            logits, _feature = model(inputs)
            probs = logits.softmax(dim=-1)
            batch_loss = bce(probs, labels)
            batch_loss.backward()
            # Clip the global gradient norm to 1 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            # Report the mean loss of every window of 100 mini-batches.
            loss_window += batch_loss.item()
            if step % 100 == 0:
                print(f'[{epoch_idx + 1}, {step:5d}] loss: {loss_window / 100:.3f}')
                loss_window = 0.0
    # Persist the trained weights.
    torch.save(model.state_dict(), model_path)
    print('Surrogate Utility Model Training Completed')

In [None]:
def eval_pub_classifier(args, model, testloader, pub_attr, device):
    """Evaluate a classifier on `testloader` and print accuracy plus the helper's report.

    Args:
        args: settings; `args.dataset` fixes how a batch is unpacked and `args.private`
            selects the private label tensor.
        model: module called as `model(images)` and returning `(logits, feature)`.
        testloader: iterable of batches `(images, act_labs, ...)`.
        pub_attr: True or False, specify if this classifier is scored on the public
            attribute (activity) or on the private attribute.
        device: torch device the batch is moved to.
    """
    n_correct = 0
    n_seen = 0
    # Seeding with an empty float array keeps the concatenated result's dtype float64.
    pred_chunks = [np.array([])]
    true_chunks = [np.array([])]
    model.eval()
    with torch.no_grad():
        for batch in testloader:
            if args.dataset == 'mobi':
                images, act_labs, gen_labs, weight_labs = batch
                if pub_attr:
                    labels = act_labs.to(device)
                elif args.private == 'gender':
                    labels = gen_labs.to(device)
                elif args.private == 'weight':
                    labels = weight_labs.to(device)
            elif args.dataset == 'wifi':
                images, act_labs, height_labs, weight_labs = batch
                if pub_attr:
                    labels = act_labs.to(device)
                elif args.private == 'height':
                    labels = height_labs.to(device)
                elif args.private == 'weight':
                    labels = weight_labs.to(device)
            elif args.dataset == 'motion':
                images, act_labs, gen_labs = batch
                labels = (act_labs if pub_attr else gen_labs).to(device)

            images = images.to(device)
            # Predicted class = position of the largest softmax probability.
            logits, _feature = model(images)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            predicted = torch.max(probs.data, 1)[1]
            pred_chunks.append(predicted.detach().cpu().numpy())
            # Reference class = arg-max of the label tensor along dim 1.
            target = torch.argmax(labels, dim=1)
            true_chunks.append(target.detach().cpu().numpy())
            n_seen += labels.size(0)
            n_correct += (predicted == target).sum().item()
    labs_pred = np.concatenate(pred_chunks, axis=0)
    labs_raw = np.concatenate(true_chunks, axis=0)
    print("Dataset:", args.dataset)
    if pub_attr:
        print("Public Attr: Activity")
    else:
        print("Private Attr:", args.private)
    print(f'Test Accuracy: {100 * n_correct / n_seen} %')
    print_accu_confmat_f1score(Y_true=labs_raw, Y_pred=labs_pred)
In [None]:
# Build the surrogate utility model; its embedding size equals the diffusion model's
# conditional-embedding dimension (diff_args.cdim).
surrogate_model = Surrogate_Utility_Classifier(args.num_public_attr, diff_args.cdim)
surrogate_model = surrogate_model.to(device)

optimizer_pub = Adam(surrogate_model.parameters(), lr=2e-4)
# File that stores (or already holds) the trained surrogate weights.
PATH_PUB = diff_args.moddir+'/pub_classifier_'+args.dataset+'_'+str(trial)+'.pt'
print(PATH_PUB)

In [None]:
# Switch: set to True to train the surrogate model here; when False the next cell
# loads previously saved weights from PATH_PUB instead.
# train_pub = True
train_pub = False

# Number of training epochs per dataset.
if args.dataset == 'mobi':
    pub_epochs = 15
elif args.dataset == 'wifi':
    pub_epochs = 25
elif args.dataset == 'motion':
    pub_epochs = 80
    
if train_pub:
    train_pub_classifier(surrogate_model, surrogate_train_loader, optimizer_pub, pub_epochs, PATH_PUB, device)
    torch.cuda.empty_cache()

In [None]:
# When training was skipped, restore the saved weights (deserialised on CPU, then moved
# to the active device) before evaluating on the public attribute.
if not train_pub:
    surrogate_model.load_state_dict(torch.load(PATH_PUB, map_location=torch.device('cpu')))
    surrogate_model.to(device)
eval_pub_classifier(args, surrogate_model, surrogate_test_loader, True, device)

## Diffusion Model

In [None]:
# assert diff_args.genbatch % (torch.cuda.device_count() * diff_args.clsnum) == 0 , 'please re-set your genbatch!!!'

# load data
dataloader = data_train_loader

# initialize models
net = Unet(
    in_ch = diff_args.inch,
    mod_ch = diff_args.modch,
    out_ch = diff_args.outch,
    ch_mul = diff_args.chmul,
    num_res_blocks = diff_args.numres,
    cdim = diff_args.cdim,
    use_conv = diff_args.useconv,
    droprate = diff_args.droprate,
    dtype = diff_args.dtype
)

betas = get_named_beta_schedule(num_diffusion_timesteps = diff_args.T)

# NOTE(review): `net` is not moved to `device` in this cell; presumably GaussianDiffusion
# does that given its `device` argument — confirm.
diffusion = GaussianDiffusion(
    dtype = diff_args.dtype,
    model = net,
    betas = betas,
    w = diff_args.w1,
    v = diff_args.v,
    device = device
)

# Maps the surrogate embedding to the conditional embedding fed to the diffusion model.
# NOTE(review): the meaning of the first argument (10) is not visible here — confirm.
cemblayer = ConditionalEmbedding(10, diff_args.cdim, diff_args.cdim).to(device)

# load last epoch
# Resume from the most recent checkpoint when a 'last_epoch.pt' marker exists.
lastpath = os.path.join(diff_args.moddir,'last_epoch.pt')
if os.path.exists(lastpath):
    lastepc = torch.load(lastpath)['last_epoch']
    # load checkpoints
    checkpoint = torch.load(os.path.join(diff_args.moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    cemblayer.load_state_dict(checkpoint['cemblayer'])
else:
    lastepc = 0

In [None]:
# optimizer settings
# One AdamW optimizer updates both the diffusion network and the embedding layer.
optimizer = torch.optim.AdamW(
    itertools.chain(
        diffusion.model.parameters(),
        cemblayer.parameters()
    ),
    lr = diff_args.lr,
    weight_decay = 1e-4
)

cosineScheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer = optimizer,
    T_max = diff_args.epoch,
    eta_min = 0,
    last_epoch = -1
)

# Warm-up for the first tenth of the epochs, then hand over to the cosine schedule.
warmUpScheduler = GradualWarmupScheduler(
    optimizer = optimizer,
    multiplier = diff_args.multiplier,
    warm_epoch = diff_args.epoch // 10,
    after_scheduler = cosineScheduler,
    last_epoch = lastepc
)

# When resuming, also restore optimizer and scheduler state from the loaded checkpoint.
if lastepc != 0:
    optimizer.load_state_dict(checkpoint['optimizer'])
    warmUpScheduler.load_state_dict(checkpoint['scheduler'])
    
# training
# cnt = torch.cuda.device_count()
cnt = 1

# Mean training loss of every epoch run in this session.
mean_losses = []

for epc in range(lastepc, diff_args.epoch):
    # turn into train mode
    diffusion.model.train()
    cemblayer.train()
    running_loss = 0.0
    num_batch = 0
    with tqdm(dataloader, dynamic_ncols=True, disable=False) as tqdmDataLoader:
        # Only the inputs are used; the label tensors of each batch are ignored.
        for img, *other in tqdmDataLoader:
            b = img.shape[0]
            optimizer.zero_grad()
            x_0 = img.to(device)
            num_batch += 1

            # The surrogate embedding is the conditioning signal; it is detached below,
            # so the surrogate model receives no gradient from the diffusion loss.
            output, emb = surrogate_model(x_0)
            
            cemb = cemblayer(emb.detach())
            # Zero the conditional embedding of a random subset of samples
            # (each with probability diff_args.threshold).
            cemb[np.where(np.random.rand(b)<diff_args.threshold)] = 0
            loss = diffusion.trainloss(x_0, cemb = cemb)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epc + 1,
                    "loss": loss.item(),
                    "batch per device":x_0.shape[0],
                    "img shape": x_0.shape[1:],
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                }
            )
    # The learning-rate schedule advances once per epoch.
    warmUpScheduler.step()
    
    print(f'[{epc + 1}, loss: {running_loss / num_batch:.3f}')
    mean_losses.append(running_loss / num_batch)
    
    # save checkpoints
    # Written every diff_args.interval epochs; 'last_epoch.pt' records which one to resume from.
    if (epc + 1) % diff_args.interval == 0:
        checkpoint = {
            'net':diffusion.model.state_dict(),
            'cemblayer':cemblayer.state_dict(),
            'optimizer':optimizer.state_dict(),
            'scheduler':warmUpScheduler.state_dict()
        }
        torch.save({'last_epoch':epc+1}, os.path.join(diff_args.moddir,'last_epoch.pt'))
        torch.save(checkpoint, os.path.join(diff_args.moddir, f'ckpt_{epc+1}_checkpoint.pt'))
    torch.cuda.empty_cache()

## Aux Privacy Model

In [None]:
class Aux_Priv_Classifier(nn.Module):
    """Private-attribute classifier conditioned on the surrogate embedding.

    Same convolution + MLP layout as the surrogate classifier, except that the flattened
    convolution features are concatenated with an embedding of size `diff_args.cdim`
    before the first linear layer. Layer sizes are picked from the notebook global
    `args.dataset`. `forward` returns `(logits, z)`.
    """

    # dataset name -> (fc1 out channels, fc2 out channels, flattened conv feature size)
    _CONV_LAYOUT = {
        'mobi': (8, 16, 16640),
        'wifi': (16, 32, 241408),  # b8
        'motion': (64, 128, 66560),
    }

    def __init__(self, num_classes, z_dim):
        super().__init__()

        self.flatten = nn.Flatten()

        # Unknown dataset names leave fc1-fc3 undefined, exactly like an unmatched if/elif.
        layout = self._CONV_LAYOUT.get(args.dataset)
        if layout is not None:
            width1, width2, flat_features = layout
            self.fc1 = nn.Conv2d(1, width1, kernel_size=2, stride=1, padding=1)
            self.fc2 = nn.Conv2d(width1, width2, kernel_size=2, stride=1, padding=1)
            self.fc3 = nn.Linear(flat_features + diff_args.cdim, 512)

        self.fc4 = nn.Linear(512, 128)
        self.fc5 = nn.Linear(128, z_dim)
        self.fc6 = nn.Linear(z_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x, emb):
        """Return `(logits, z)` for inputs `x` conditioned on the embedding `emb`."""
        act = self.relu
        conv_flat = self.flatten(act(self.fc2(act(self.fc1(x)))))
        # Append the conditioning embedding to every sample's flattened features.
        joint = torch.cat((conv_flat, emb), dim=1)
        hidden = act(self.fc3(joint))
        hidden = act(self.fc4(hidden))
        hidden = act(self.fc5(hidden))
        z = torch.nn.functional.normalize(hidden, dim=-1)
        return self.fc6(z), z

In [None]:
def train_priv_classifier(priv_model, surrogate_model, trainloader, optimizer, optimizer_mine, mi_estimator, w3, epochs, model_path, device):
    """Train the auxiliary privacy classifier together with a mutual-information estimator.

    Per batch, the estimator is first updated on the detached classifier feature; the
    classifier is then updated with `cls_loss - w3 * mi_loss`, where `mi_loss` is the
    estimator's output on the non-detached feature.

    Args:
        priv_model: classifier called as `priv_model(inputs, emb)` returning `(logits, feature)`.
        surrogate_model: model called as `surrogate_model(inputs)`; only its second
            output (the embedding) is used, without gradient.
        trainloader: iterable of `(inputs, labels, pub_labels)` batches.
        optimizer: optimizer of `priv_model`.
        optimizer_mine: optimizer of `mi_estimator`.
        mi_estimator: callable as `mi_estimator(pub_labels, feature)` returning a scalar loss.
        w3: weight of the estimator term in the classifier loss.
        epochs: number of passes over `trainloader`.
        model_path: destination of the saved `priv_model` state dict.
        device: torch device the batches are moved to.
    """
    priv_model.train()
    criterion_base = torch.nn.BCELoss()
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        running_cls_loss = 0.0
        running_mi_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels, pub_labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            pub_labels = pub_labels.to(device)
            
            # zero the parameter gradients
            optimizer.zero_grad()
            optimizer_mine.zero_grad()

            # condition private classifier on z_p
            with torch.no_grad():
                _output, emb = surrogate_model(inputs)

            logits, feature = priv_model(inputs, emb)

            # train mi_estimator separately due to the reversed loss
            # (the feature is detached so this step does not touch priv_model)
            mi_train_loss = mi_estimator(pub_labels, feature.detach())
            mi_train_loss.backward()
            optimizer_mine.step()
            outputs = torch.nn.functional.softmax(logits, dim=-1)
            
            cls_loss = criterion_base(outputs, labels)
            mi_loss = mi_estimator(pub_labels, feature)
            loss = cls_loss - w3 * mi_loss
            
            # This backward pass also leaves gradients on mi_estimator; they are cleared
            # by optimizer_mine.zero_grad() at the start of the next iteration.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(priv_model.parameters(), 1)
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_cls_loss += cls_loss.item()
            running_mi_loss += mi_loss.item()
            if i % 100 == 99:    # print every 100 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}', f'cls_loss: {running_cls_loss / 100:.3f}', f'mi_loss: {running_mi_loss / 100:.3f}')
                running_loss = 0.0
                running_cls_loss = 0.0
                running_mi_loss = 0.0

    # Save trained models
    torch.save(priv_model.state_dict(), model_path)
    print('Auxiliary Privacy Classifier Training Completed')

In [None]:
# pub_attr: True or False, specify if this aux classifier is for public attribute or private attribute
def eval_priv_classifier(args, priv_model, surrogte_model, testloader, pub_attr, device):  
    """Evaluate the auxiliary privacy classifier and print accuracy plus the helper's report.

    Each batch is first passed through the surrogate model to obtain the embedding that
    conditions `priv_model`. Predictions are the arg-max of the softmaxed logits and are
    compared with the arg-max (dim 1) of the selected label tensor.

    Args:
        args: settings; `args.dataset` ('mobi', 'wifi' or 'motion') fixes how a batch is
            unpacked and `args.private` selects the private label tensor.
        priv_model: classifier called as `priv_model(images, emb)` returning `(logits, feature)`.
        surrogte_model: surrogate model called as `model(images)` returning `(output, emb)`.
            The misspelled parameter name is kept so existing keyword callers keep working.
        testloader: iterable of batches `(images, act_labs, ...)`.
        pub_attr: True to score against the activity labels, False for the private labels.
        device: torch device the batch is moved to.

    Returns:
        None. Results are printed and forwarded to `print_accu_confmat_f1score`.
    """
    # Use the model that was actually passed in. The body previously read the name
    # `surrogate_model`, which does not match the parameter, so the argument was ignored
    # and a notebook-level global was used instead.
    surrogate_model = surrogte_model

    correct = 0
    total = 0
    labs_pred = np.array([])
    labs_raw = np.array([])
    # since we're not training, we don't need to calculate the gradients for our outputs
    priv_model.eval()
    with torch.no_grad():
        for data in testloader:
            if args.dataset == 'mobi':
                images, act_labs, gen_labs, weight_labs = data
                if pub_attr:
                    labels = act_labs.to(device)
                else:
                    if args.private == 'gender':
                        labels = gen_labs.to(device)
                    elif args.private == 'weight':
                        labels = weight_labs.to(device)

            elif args.dataset == 'wifi':
                images, act_labs, height_labs, weight_labs = data
                if pub_attr:
                    labels = act_labs.to(device)
                else:
                    if args.private == 'height':
                        labels = height_labs.to(device)
                    elif args.private == 'weight':
                        labels = weight_labs.to(device)

            elif args.dataset == 'motion':
                images, act_labs, gen_labs = data
                if pub_attr:
                    labels = act_labs.to(device)
                else:
                    labels = gen_labs.to(device)
            images = images.to(device)

            # Embedding that conditions the privacy classifier (already inside no_grad).
            _output, emb = surrogate_model(images)

            logits, feature = priv_model(images, emb)
            outputs = torch.nn.functional.softmax(logits, dim=-1)
            _, predicted = torch.max(outputs.data, 1)
            labs_pred = np.concatenate((labs_pred, predicted.detach().cpu().numpy()), axis=0 )
            labs_raw = np.concatenate((labs_raw, torch.argmax(labels, dim=1).detach().cpu().numpy()), axis=0)
            total += labels.size(0)
            correct += (predicted == torch.argmax(labels, dim=1)).sum().item()
    print("Dataset:", args.dataset)
    if pub_attr:
        print("Public Attr: Activity")
    else:
        print("Private Attr:", args.private)
    print(f'Test Accuracy: {100 * correct / total} %')
    print_accu_confmat_f1score(Y_true=labs_raw, Y_pred=labs_pred)

In [None]:
# The surrogate model only provides embeddings from here on.
surrogate_model.eval()

# to disentangle pub & priv attr
# Mutual-information estimator between the public labels (x_dim) and the classifier
# feature (z_dim); T and Mine come from the project's `mine` package.
t = T(x_dim=args.num_public_attr, z_dim=diff_args.cdim).to(device)
mi_estimator = Mine(t, loss='mine').to(device)

priv_classifier = Aux_Priv_Classifier(args.num_private_attr, diff_args.cdim)    
priv_classifier = priv_classifier.to(device)

# Separate optimizers: the estimator and the classifier are updated in separate steps.
optimizer_priv = Adam(priv_classifier.parameters(), lr=2e-4)
optimizer_mine = Adam(mi_estimator.parameters(), lr=2e-4)

# File that stores (or already holds) the trained privacy classifier weights.
PATH_PRIV = diff_args.moddir+'/priv_classifier_'+args.dataset+'_'+args.private+'_'+str(trial)+'.pt'
print(PATH_PRIV)

In [None]:
# Switch: set to True to train the privacy classifier here; when False the next cell
# loads previously saved weights from PATH_PRIV instead.
# train_priv = True
train_priv = False

# w3 weights the mutual-information term of the classifier loss; epochs vary per dataset.
w3 = 4
if args.dataset == 'mobi':
    priv_epochs = 20
    w3 = 8
elif args.dataset == 'wifi':
    priv_epochs = 20
elif args.dataset == 'motion':
    priv_epochs = 80
    
if train_priv:
    train_priv_classifier(priv_classifier, surrogate_model, priv_train_loader, optimizer_priv, optimizer_mine, mi_estimator, w3, priv_epochs, PATH_PRIV, device)

In [None]:
# When training was skipped, restore the saved weights (deserialised on CPU, then moved
# to the active device) before evaluating on the private attribute.
if not train_priv:
    priv_classifier.load_state_dict(torch.load(PATH_PRIV, map_location=torch.device('cpu')))
    priv_classifier.to(device)
eval_priv_classifier(args, priv_classifier, surrogate_model, surrogate_test_loader, False, device)

In [None]:
# Tune privacy-utility trade-off post training
# These values are passed explicitly to diffusion.ddim_sample (w1=..., w2=...) in the
# sampling cell; they do not change the `w` given to GaussianDiffusion at construction.
diff_args.w1 = 7.8
diff_args.w2 = 0.05

In [None]:
# Generate an obfuscated counterpart for every test batch with DDIM sampling and collect
# raw inputs, generated samples and the label arrays as NumPy arrays.
raw_test_imgs = None
recon_test_imgs = None
act_test_labs = None
priv_test_labs = None
weight_test_labs = None
is_first_batch = True

with torch.no_grad():
    for data in surrogate_test_loader:
        # Unpack per dataset: `priv_labels` is the third tensor of the batch
        # (gender for mobi/motion, height for wifi), `labels_weight` the fourth.
        if args.dataset == 'mobi':
            images, labels, priv_labels, labels_weight = data
            if args.private == 'gender':
                priv_labels = priv_labels.to(device)
            elif args.private == 'weight':
                labels_weight = labels_weight.to(device)
        elif args.dataset == 'wifi':
            images, labels, priv_labels, labels_weight = data
            if args.private == 'height':
                priv_labels = priv_labels.to(device)
            elif args.private == 'weight':
                labels_weight = labels_weight.to(device)
        elif args.dataset == 'motion':
            # NOTE(review): unlike the other branches, priv_labels is not moved to
            # `device` here, so `sample_priv_labels` stays on the loader's device;
            # confirm ddim_sample moves `priv_y` itself.
            images, labels, priv_labels = data
            
        images = images.to(device)
        labels = labels.to(device)    

        # Conditioning: surrogate embedding and its conditional embedding.
        _output, emb = surrogate_model(images)
        emb = emb.detach()       
        cemb = cemblayer(emb)
        
        # Shape of the generated batch; every batch has args.batch_size samples because
        # the test loader was built with drop_last=True.
        if args.dataset == 'mobi':
            genshape = (args.batch_size, 1, 6, 128)
        elif args.dataset =='wifi':
            genshape = (args.batch_size , 1, 90, 80)
        elif args.dataset =='motion':
            genshape = (args.batch_size , 1, 2, 128)

        # Class indices of the private attribute selected by args.private.
        sample_priv_labels = torch.argmax(priv_labels, dim=1)
        if args.dataset == 'mobi':
            if args.private == 'weight':
                sample_priv_labels = torch.argmax(labels_weight, dim=1)
        elif args.dataset == 'wifi':
            if args.private == 'weight':
                sample_priv_labels = torch.argmax(labels_weight, dim=1)

        if diff_args.ddim:
            generated = diffusion.ddim_sample(genshape, diff_args.num_steps, diff_args.eta, diff_args.select, priv_classifier=priv_classifier, priv_y=sample_priv_labels, emb=emb, w1=diff_args.w1, w2=diff_args.w2, cemb = cemb)
        else:
            raise ValueError('DDPM version of PrivDiffuser not implemented.')

        synthesized_img = generated
        
        # Accumulate results; weight labels do not exist for the 'motion' dataset.
        if is_first_batch:
            raw_test_imgs = images.detach().cpu().numpy()
            recon_test_imgs = synthesized_img.detach().cpu().numpy()
            act_test_labs = labels.detach().cpu().numpy()
            if args.dataset != 'motion':
                weight_test_labs = labels_weight.detach().cpu().numpy()
            priv_test_labs = priv_labels.detach().cpu().numpy()
            is_first_batch = False
        else:
            raw_test_imgs = np.concatenate((raw_test_imgs, images.detach().cpu().numpy()), axis=0)
            recon_test_imgs = np.concatenate((recon_test_imgs, synthesized_img.detach().cpu().numpy()), axis=0)    
            act_test_labs = np.concatenate((act_test_labs, labels.detach().cpu().numpy()), axis=0)
            if args.dataset != 'motion':
                weight_test_labs = np.concatenate((weight_test_labs, labels_weight.detach().cpu().numpy()), axis=0)
            priv_test_labs = np.concatenate((priv_test_labs, priv_labels.detach().cpu().numpy()), axis=0)

In [None]:
# Compute MSE between obfuscated data and raw data
# Each sample is flattened to one row before calling sklearn's mean_squared_error.
num_eval_pics = raw_test_imgs.shape[0]
mse = mean_squared_error(y_true=raw_test_imgs.reshape(num_eval_pics,-1), y_pred=recon_test_imgs.reshape(num_eval_pics,-1))
print("mse of %d synthesized images: \n%f" % (num_eval_pics, mse))

In [None]:
# Visualize raw data & reconstructed data
# Off by default; set the flag to True to call the project's plot_compare helper.
plot_comparison = False
if plot_comparison:
    plot_compare(raw_test_imgs, recon_test_imgs, 20)

## Evaluate Privacy & Utility Loss on obfuscated data

In [None]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all up front
# (presumably so the Keras evaluation models can coexist with the PyTorch models on the
# same GPU). The previous version referenced the names `tensorflow` and
# `physical_devices`, neither of which is defined by the visible import cell, and a bare
# `except` hid the resulting NameError, so memory growth was never enabled.
import tensorflow

try:
    # Memory growth is set for every visible GPU so the setting is uniform across devices.
    for gpu in tensorflow.config.list_physical_devices('GPU'):
        tensorflow.config.experimental.set_memory_growth(gpu, True)
except (ValueError, RuntimeError) as err:
    # Invalid device or cannot modify virtual devices once initialized.
    # Best effort: report the problem and continue with TensorFlow's default allocation.
    print("Could not enable TensorFlow GPU memory growth:", err)

In [None]:
# Load HAR model & intrusive inference model
# Pre-trained Keras models used to score the obfuscated data: an activity model for
# every dataset, plus gender/weight models for 'mobi', a gender model for 'motion' and a
# weight model for 'wifi' (no height model is loaded).
path_eval = "./eval_models/"
if args.dataset == 'mobi':
    eval_act_model = load_model(path_eval + "MobiAct/activity_model_DC.hdf5")
    eval_gender_model = load_model(path_eval + "MobiAct/gender_model_DC.hdf5")
    eval_weight_model = load_model(path_eval + "MobiAct/weight_model_DC.hdf5")
elif args.dataset == 'motion':
    eval_act_model = load_model(path_eval + "MotionSense/activity_model_mlp.hdf5")
    eval_gender_model = load_model(path_eval + "MotionSense/gender_model_mlp.hdf5")
elif args.dataset == 'wifi':
    eval_act_model = load_model(path_eval + "WiFi-HAR/wifi_activity_model_sit_80.hdf5")
    eval_weight_model = load_model(path_eval + "WiFi-HAR/wifi_weight_model_sit_80.hdf5")     

In [None]:
# Human-readable class names passed to the reporting helpers as `txt_labels`.

# Public attribute (activity) names per dataset.
if args.dataset == 'mobi':
    public_txt_labels = ["wlk", "std", "jog", "ups"]
elif args.dataset =='motion':
    public_txt_labels = ["dws", "ups", "wlk", "jog"]

# Private attribute names; the WiFi-specific values are set in the block below.
if args.private == 'gender':
    private_txt_labels = ["m", "f"]
elif args.private == 'weight':
    private_txt_labels = ["<=70", "70-90", ">90"]
# Weight bins used by the three-class weight report.
weight_txt_labels = ["<=70", "70-90", ">90"]

if args.dataset == 'wifi':
    public_txt_labels = ["sit", "fall", "lie", "std"]
    # The WiFi weight model is binary (its predictions are one-hot encoded with
    # num_classes=2), so the weight report needs two names. Previously the three bins
    # above were left in place for this dataset.
    weight_txt_labels = ["<=80", ">80"]
    if args.private == 'weight':
        private_txt_labels = ["<=80", ">80"]
    elif args.private == 'height':
        private_txt_labels = ["<=175", ">175"]

In [None]:
# Format obfuscated data
# Reshape the generated samples to the input layout of the Keras evaluation models:
# channel-last 4-D arrays for 'mobi' and 'wifi', flat 256-value vectors for 'motion'.
if args.dataset == 'mobi':
    synthesized_dataset = np.reshape(recon_test_imgs, (recon_test_imgs.shape[0], 6, 128, 1))
elif args.dataset == 'wifi':
    synthesized_dataset = np.reshape(recon_test_imgs, (recon_test_imgs.shape[0], 90, 80, 1))
elif args.dataset == 'motion':
    synthesized_dataset = np.reshape(recon_test_imgs, (recon_test_imgs.shape[0], 256))
synthesized_act_labels = act_test_labs
synthesized_private_labels = priv_test_labs
# Stays None for the 'motion' dataset, which has no weight labels.
synthesized_weight_labels = weight_test_labs

In [None]:
# Save obfuscated data
# Everything is written next to the checkpoints in diff_args.moddir.
# NOTE(review): for 'motion' synthesized_weight_labels is None, so that file holds a
# pickled object array rather than numeric data — confirm this is intended.
np.save(diff_args.moddir + '/synthesized_dataset_ddim.npy', synthesized_dataset)
np.save(diff_args.moddir + '/raw_test_imgs_ddim.npy', raw_test_imgs)
np.save(diff_args.moddir + '/synthesized_act_labels_ddim.npy', synthesized_act_labels)
np.save(diff_args.moddir + '/synthesized_private_labels_ddim.npy', synthesized_private_labels)
np.save(diff_args.moddir + '/synthesized_weight_labels_ddim.npy', synthesized_weight_labels)

In [None]:
# # Load obfuscated data
# synthesized_dataset = np.load(diff_args.moddir + '/synthesized_dataset_ddim.npy') # save
# raw_test_imgs = np.load(diff_args.moddir + '/raw_test_imgs_ddim.npy')
# synthesized_act_labels = np.load(diff_args.moddir + '/synthesized_act_labels_ddim.npy') # save
# synthesized_private_labels = np.load(diff_args.moddir + '/synthesized_private_labels_ddim.npy') # save
# synthesized_weight_labels = np.load(diff_args.moddir + '/synthesized_weight_labels_ddim.npy') # save

In [None]:
# Utility: how well the pre-trained activity model still recognises the obfuscated data.
# `to_categorical`, `print_accu_score` and `print_accu_confmat_f1score` are not imported
# explicitly; presumably they come from the star-imported project modules — confirm.
print("Activity Identification:")
Y_act = eval_act_model.predict(synthesized_dataset, verbose=args.verbose)
Y_act_labels = np.argmax(Y_act, axis=1)  # generate predicted vector of labels
pred_act = to_categorical(Y_act_labels, num_classes=args.num_public_attr)
print_accu_score(Y_true=synthesized_act_labels, Y_pred=pred_act)
print_accu_confmat_f1score(Y_true=np.argmax(synthesized_act_labels, axis=1), Y_pred=Y_act_labels, txt_labels=public_txt_labels)

# Privacy: how well the intrusive models still infer private attributes.
if args.dataset == 'mobi':
    print("Gender Identification:")
    Y_gen = eval_gender_model.predict(synthesized_dataset, verbose=args.verbose)
    # The gender model's output is thresholded at 0.5 to obtain a binary label.
    Y_gen_labels = np.where(Y_gen > 0.5, 1, 0)
    pred_gen = to_categorical(Y_gen_labels, num_classes=2)        
    print_accu_score(Y_true=synthesized_private_labels, Y_pred=pred_gen)
    print_accu_confmat_f1score(Y_true=np.argmax(synthesized_private_labels, axis=1), Y_pred=Y_gen_labels, txt_labels=private_txt_labels)

    print("Weight Identification:")
    # Y_gen / Y_gen_labels / pred_gen are reused here for the weight predictions.
    Y_gen = eval_weight_model.predict(synthesized_dataset, verbose=args.verbose)
    Y_gen_labels = np.argmax(Y_gen, axis=1)
    pred_gen = to_categorical(Y_gen_labels, num_classes=3)        
    print_accu_score(Y_true=synthesized_weight_labels, Y_pred=pred_gen)
    print_accu_confmat_f1score(Y_true=np.argmax(synthesized_weight_labels, axis=1), Y_pred=Y_gen_labels, txt_labels=weight_txt_labels)

elif args.dataset == 'motion':
    print("Gender Identification:")
    Y_gen = eval_gender_model.predict(synthesized_dataset, verbose=args.verbose)
    Y_gen_labels = np.where(Y_gen > 0.5, 1, 0)
    pred_gen = to_categorical(Y_gen_labels, num_classes=args.num_private_attr)        
    print_accu_score(Y_true=synthesized_private_labels, Y_pred=pred_gen)
    print_accu_confmat_f1score(Y_true=np.argmax(synthesized_private_labels, axis=1), Y_pred=Y_gen_labels, txt_labels=private_txt_labels)

elif args.dataset == 'wifi':
    # Only the weight model is evaluated for WiFi, whatever args.private is.
    # NOTE(review): this branch encodes two classes, so `weight_txt_labels` must hold
    # exactly two names for this dataset — confirm against the label definitions.
    print("Weight Identification:")
    Y_gen = eval_weight_model.predict(synthesized_dataset, verbose=args.verbose)
    Y_gen_labels = np.argmax(Y_gen, axis=1)
    pred_gen = to_categorical(Y_gen_labels, num_classes=2)  
    print_accu_score(Y_true=synthesized_weight_labels, Y_pred=pred_gen)
    print_accu_confmat_f1score(Y_true=np.argmax(synthesized_weight_labels, axis=1), Y_pred=Y_gen_labels, txt_labels=weight_txt_labels)