## This notebook contains code for:

### Dataset partitions
We require the following partitions of the dataset:
- Private training data (size=10,000)
- Reference data (size=10,000)
- Validation data (size=5,000)
- Test data (size=5,000)
- Members to train the membership inference attack (MIA) model: 5,000 samples from the private training data
- Non-members to train the MIA model: 5,000 samples that are not used in any of the other dataset partitions


### DMP training
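DMP (Distillation For Membership Privacy) works in three steps: we first train an unprotected model on the private data, then distill its knowledge into the randomly sampled reference data (i.e., we label each reference sample with the unprotected model's outputs), and finally use these distilled labels to train the final protected model.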
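In this notebook the "distilled knowledge" is the unprotected model's raw logits on each reference sample; `train_pub` later turns them into soft targets with `F.softmax(targets/t_softmax, dim=1)` (with `t_softmax=1`). A minimal NumPy sketch of that temperature-scaled softmax on made-up logits; `soft_labels` and `entropy` are hypothetical helpers, not part of the repository. It illustrates that a higher temperature flattens the targets without changing the predicted class:

```python
import numpy as np

def soft_labels(teacher_logits, temperature=1.0):
    """Temperature-scaled softmax of teacher logits, one probability row per sample."""
    z = np.asarray(teacher_logits, dtype=np.float64) / temperature
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def entropy(p):
    """Shannon entropy (in nats) of each probability row."""
    return -(p * np.log(p)).sum(axis=1)

# made-up teacher logits for two reference samples and three classes
logits = np.array([[4.0, 1.0, 0.0],
                   [0.5, 2.5, 0.0]])

for t in (0.5, 1.0, 4.0):
    p = soft_labels(logits, t)
    print(t, p.round(3), entropy(p).round(3))
```

Lower temperatures push the targets toward one-hot labels; higher temperatures expose more of the teacher's relative preferences among the non-top classes.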

### A blackbox membership inference attack to evaluate the MIA risk of DMP models
Although the literature contains several blackbox MIAs, we use the one proposed in "Comprehensive privacy analysis of deep learning: Passive and active white-box inference attacks against centralized and federated learning" (Nasr et al., IEEE S&P 2019). We do not provide code for whitebox MIAs, because many works (including ours) observe that whitebox MIAs are not substantially stronger than blackbox MIAs.
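Evaluating MIA risk boils down to pooling member and non-member samples, labeling them 1/0, and measuring how often the attack guesses membership correctly; on the balanced member/non-member splits used here, 50% corresponds to random guessing. The sketch below shows only that bookkeeping in NumPy. We assume the attack model consumes the target model's prediction vector together with the one-hot true label; `build_attack_dataset` and `attack_accuracy` are hypothetical helpers, and the "confidence on the true label" score is a placeholder, not the Nasr et al. attack model:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax of an (n_samples, n_classes) array."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def build_attack_dataset(member_logits, member_labels,
                         nonmember_logits, nonmember_labels, n_classes):
    """Stack members and non-members into (prediction vectors, one-hot labels, membership bits)."""
    probs = softmax(np.concatenate([member_logits, nonmember_logits]))
    labels = np.concatenate([member_labels, nonmember_labels])
    onehot = np.eye(n_classes)[labels]
    membership = np.concatenate([np.ones(len(member_labels)),
                                 np.zeros(len(nonmember_labels))])
    return probs, onehot, membership

def attack_accuracy(scores, membership, threshold=0.5):
    """Fraction of samples whose membership is guessed correctly from attack scores in [0, 1]."""
    guesses = (np.asarray(scores) >= threshold).astype(np.float64)
    return float((guesses == membership).mean())

# random stand-ins for the target model's logits on 6 members and 6 non-members
rng = np.random.default_rng(0)
n_classes = 4
probs, onehot, membership = build_attack_dataset(
    rng.normal(size=(6, n_classes)), rng.integers(0, n_classes, 6),
    rng.normal(size=(6, n_classes)), rng.integers(0, n_classes, 6), n_classes)

# Placeholder attack score: the model's confidence on the true label.
# A trained attack model would replace this line.
scores = (probs * onehot).sum(axis=1)
print(probs.shape, onehot.shape, attack_accuracy(scores, membership))
```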

In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))



In [2]:
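# explicit imports for the names used below (the star imports may already provide them)
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from purchase_models import *
from purchase_normal_train import *
from purchase_private_train import *
from purchase_attack_train import *
from purchase_util import *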

In [1]:
import os
cwd = os.getcwd()
print(cwd)

/home/nyanmaruk/Code/Research/Security/AAAI21-MIA-Defense/purchase


In [4]:
# data loading
data_loc='/purchase_data/dataset_purchase'
private_data_len=10000
ref_data_len=10000
te_len=5000
val_len=5000
attack_tr_len=10000
attack_te_len=10000


tr_frac=0.5 # we use 50% of private data as the members to train MIA model
val_frac=0.25 # we use 25% of private data as the members to validate MIA model
te_frac=0.25 # we use 25% of private data as the members to test MIA model

data_set = np.genfromtxt(data_loc, delimiter=',')
X = data_set[:,1:].astype(np.float64)
Y = (data_set[:,0]).astype(np.int32)-1

print('total data len: ',len(X))

if not os.path.isfile('./purchase_shuffle.pkl'):
    # shuffle once and cache the permutation so that every run uses the same partitions
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    with open('./purchase_shuffle.pkl','wb') as f:
        pickle.dump(all_indices, f)
else:
    with open('./purchase_shuffle.pkl','rb') as f:
        all_indices=pickle.load(f)


private_data=X[all_indices[:private_data_len]]
private_label=Y[all_indices[:private_data_len]]

ref_data=X[all_indices[private_data_len : (private_data_len + ref_data_len)]]
ref_label=Y[all_indices[private_data_len : (private_data_len + ref_data_len)]]

val_data=X[all_indices[(private_data_len + ref_data_len):(private_data_len + ref_data_len + val_len)]]
val_label=Y[all_indices[(private_data_len + ref_data_len):(private_data_len + ref_data_len + val_len)]]

te_data=X[all_indices[(private_data_len + ref_data_len + val_len):(private_data_len + ref_data_len + val_len + te_len)]]
te_label=Y[all_indices[(private_data_len + ref_data_len + val_len):(private_data_len + ref_data_len + val_len + te_len)]]

attack_te_data=X[all_indices[(private_data_len + ref_data_len + val_len+ te_len):(private_data_len + ref_data_len + val_len+ te_len + attack_te_len)]]
attack_te_label=Y[all_indices[(private_data_len + ref_data_len + val_len+ te_len):(private_data_len + ref_data_len + val_len+ te_len + attack_te_len)]]

attack_tr_data=X[all_indices[(private_data_len + ref_data_len + val_len+ te_len + attack_te_len):(private_data_len + ref_data_len + val_len+ te_len + attack_te_len + attack_tr_len)]]
attack_tr_label=Y[all_indices[(private_data_len + ref_data_len + val_len+ te_len + attack_te_len):(private_data_len + ref_data_len + val_len+ te_len + attack_te_len + attack_tr_len)]]

remaining_data=X[all_indices[(private_data_len + ref_data_len + val_len+ te_len + attack_te_len + attack_tr_len):]]
remaining_label=Y[all_indices[(private_data_len + ref_data_len + val_len+ te_len + attack_te_len + attack_tr_len):]]


# get private data and label tensors required to train the unprotected model
private_data_tensor=torch.from_numpy(private_data).type(torch.FloatTensor)
private_label_tensor=torch.from_numpy(private_label).type(torch.LongTensor)


# get reference data and label tensors required to distil the knowledge into the protected model
ref_indices=np.arange((ref_data_len))
ref_data_tensor=torch.from_numpy(ref_data).type(torch.FloatTensor)
ref_label_tensor=torch.from_numpy(ref_label).type(torch.LongTensor)


## Member tensors required to train, validate, and test the MIA model:
# get member data and label tensors required to train MIA model
mia_train_members_data_tensor=private_data_tensor[:int(tr_frac*private_data_len)]
mia_train_members_label_tensor=private_label_tensor[:int(tr_frac*private_data_len)]

# get member data and label tensors required to validate MIA model
mia_val_members_data_tensor=private_data_tensor[int(tr_frac*private_data_len):int((tr_frac+val_frac)*private_data_len)]
mia_val_members_label_tensor=private_label_tensor[int(tr_frac*private_data_len):int((tr_frac+val_frac)*private_data_len)]

# get member data and label tensors required to test MIA model
mia_test_members_data_tensor=private_data_tensor[int((tr_frac+val_frac)*private_data_len):]
mia_test_members_label_tensor=private_label_tensor[int((tr_frac+val_frac)*private_data_len):]



## Non-member tensors required to train, validate, and test the MIA model:
attack_tr_data_tensors = torch.from_numpy(attack_tr_data).type(torch.FloatTensor)
attack_tr_label_tensors = torch.from_numpy(attack_tr_label).type(torch.LongTensor)

# get non-member data and label tensors required to train MIA model
mia_train_nonmembers_data_tensor = attack_tr_data_tensors[:int(tr_frac*private_data_len)]
mia_train_nonmembers_label_tensor = attack_tr_label_tensors[:int(tr_frac*private_data_len)]

# get non-member data and label tensors required to validate MIA model
mia_val_nonmembers_data_tensor = attack_tr_data_tensors[int(tr_frac*private_data_len):int((tr_frac+val_frac)*private_data_len)]
mia_val_nonmembers_label_tensor = attack_tr_label_tensors[int(tr_frac*private_data_len):int((tr_frac+val_frac)*private_data_len)]

# get non-member data and label tensors required to test MIA model
mia_test_nonmembers_data_tensor = attack_tr_data_tensors[int((tr_frac+val_frac)*private_data_len):]
mia_test_nonmembers_label_tensor = attack_tr_label_tensors[int((tr_frac+val_frac)*private_data_len):]



# get non-member data and label tensors required to test the MIA model
attack_te_data_tensor=torch.from_numpy(attack_te_data).type(torch.FloatTensor)
attack_te_label_tensor=torch.from_numpy(attack_te_label).type(torch.LongTensor)


## Tensors required to validate and test the unprotected and protected models
# get validation data and label tensors
val_data_tensor=torch.from_numpy(val_data).type(torch.FloatTensor)
val_label_tensor=torch.from_numpy(val_label).type(torch.LongTensor)

# get test data and label tensors
te_data_tensor=torch.from_numpy(te_data).type(torch.FloatTensor)
te_label_tensor=torch.from_numpy(te_label).type(torch.LongTensor)


print('tr len %d | mia_members tr %d val %d te %d | mia_nonmembers tr %d val %d te %d | ref len %d | val len %d | test len %d | attack te len %d | remaining data len %d'%
      (len(private_data_tensor),len(mia_train_members_data_tensor),len(mia_val_members_data_tensor),len(mia_test_members_data_tensor),
       len(mia_train_nonmembers_data_tensor), len(mia_val_nonmembers_data_tensor),len(mia_test_nonmembers_data_tensor),
       len(ref_data_tensor),len(val_data_tensor),len(te_data_tensor),len(attack_te_data_tensor), len(remaining_data)))



total data len:  197324
tr len 10000 | mia_members tr 5000 val 2500 te 2500 | mia_nonmembers tr 5000 val 2500 te 2500 | ref len 10000 | val len 5000 | test len 5000 | attack te len 10000 | remaining data len 147324
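The long slicing expressions in the data-loading cell cut one shuffled permutation into consecutive, non-overlapping chunks. A compact sketch of the same idea using cumulative offsets; `partition_indices` is a hypothetical helper, and it assumes a seeded `numpy.random.default_rng(...).permutation` in place of the pickled `np.random.shuffle` output, so the actual indices will differ from the notebook's. With the notebook's sizes (in the same order as the code) it reproduces the printed lengths, including the 147,324 remaining samples:

```python
import numpy as np

def partition_indices(n_total, sizes, seed=None):
    """Shuffle range(n_total) once and cut it into consecutive, disjoint chunks.

    `sizes` maps partition name -> length (in slicing order); leftovers go to 'remaining'.
    """
    perm = np.random.default_rng(seed).permutation(n_total)
    names = list(sizes)
    bounds = np.cumsum([0] + [sizes[k] for k in names])  # start/end offset of each chunk
    if bounds[-1] > n_total:
        raise ValueError("partitions need more samples than the dataset has")
    parts = {k: perm[bounds[i]:bounds[i + 1]] for i, k in enumerate(names)}
    parts["remaining"] = perm[bounds[-1]:]
    return parts

sizes = {"private": 10000, "ref": 10000, "val": 5000,
         "test": 5000, "attack_te": 10000, "attack_tr": 10000}
parts = partition_indices(197324, sizes, seed=0)
print({k: len(v) for k, v in parts.items()})
```

Each partition is then just `X[parts["private"]]`, `Y[parts["private"]]`, and so on; because all chunks come from a single permutation, they are disjoint by construction.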


In [12]:
class PurchaseClassifier(nn.Module):
    def __init__(self,num_classes=100):
        super(PurchaseClassifier, self).__init__()

        self.features = nn.Sequential(
            nn.Linear(600,1024),
            nn.Tanh(),
            nn.Linear(1024,512),
            nn.Tanh(),
            nn.Linear(512,256),
            nn.Tanh(),
            nn.Linear(256,128),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(128,num_classes)
        
    def forward(self,inp):
        
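        outputs=[]
        x=inp
        # modules() yields the Sequential container itself first; [1:] keeps its 8 layers
        module_list =list(self.features.modules())[1:]
        for l in module_list:
            
            x = l(x)
            outputs.append(x)  # keep every intermediate activation
        
        y = x.view(inp.size(0), -1)
        o = self.classifier(y)
        
        # returns the logits, the final 128-d Tanh activation (outputs[-1]),
        # and the 256-d output of Linear(512, 256) before its Tanh (outputs[-4])
        return o, outputs[-1].view(inp.size(0), -1), outputs[-4].view(inp.size(0), -1)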
        

## Check the size of the model (number of trainable parameters)

In [13]:
model=PurchaseClassifier()
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
sum([np.prod(p.size()) for p in model_parameters])

1317348
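The count of 1,317,348 can be checked by hand: each `nn.Linear(in, out)` holds `in*out` weights plus `out` biases, and the `Tanh` layers have no parameters. A short sketch of that arithmetic for `PurchaseClassifier` with the default `num_classes=100`:

```python
# (in_features, out_features) of every nn.Linear in PurchaseClassifier (num_classes=100)
linear_shapes = [(600, 1024), (1024, 512), (512, 256), (256, 128), (128, 100)]

# weights (in * out) plus biases (out) for each layer
n_params = sum(i * o + o for i, o in linear_shapes)
print(n_params)  # 1317348
```

Almost half of the parameters sit in the first layer (600 × 1024 weights), since the Purchase features are 600-dimensional.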

In [16]:
def train_pub(train_data, labels, true_labels, model, t_softmax, optimizer, num_batchs=999999, batch_size=16, alpha=1):
    # switch to train mode
    model.train()
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    true_criterion=nn.CrossEntropyLoss()

    len_t = len(train_data)//batch_size
    if len(train_data) % batch_size:
        len_t += 1
    
    for ind in range(len_t):
        if ind > num_batchs:
            break

        inputs = train_data[ind*batch_size:(ind+1)*batch_size]
        targets = labels[ind*batch_size:(ind+1)*batch_size]
        true_targets=true_labels[ind*batch_size:(ind+1)*batch_size]


        inputs, targets, true_targets = inputs.cuda(), targets.cuda(), true_targets.cuda()

        # compute output
        outputs, _, _ = model(inputs)

        # distillation loss between the softened student outputs and the softened teacher logits (targets),
        # plus an optional cross-entropy term on the true labels; alpha=1 uses the distillation term only
        loss = alpha*F.kl_div(F.log_softmax(outputs/t_softmax, dim=1), F.softmax(targets/t_softmax, dim=1)) + (1-alpha)*true_criterion(outputs,true_targets)
        
        # measure loss
        losses.update(loss.item(), inputs.size(0))
        
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg)
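`train_pub` calls `F.kl_div` with PyTorch's default `reduction='mean'`, which averages over every element (batch size × number of classes) rather than per sample; PyTorch's documentation recommends `reduction='batchmean'` for the mathematically correct KL divergence. With 100 classes, the "distil loss" printed in the logs below is therefore 1/100 of the per-sample KL. A NumPy sketch of the two reductions, assuming `F.kl_div`'s convention of a log-probability input and a probability target; `kl_div` here is a hypothetical re-implementation, not the PyTorch function:

```python
import numpy as np

def log_softmax(z):
    """Numerically stable row-wise log-softmax."""
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def kl_div(student_logits, teacher_logits, t=1.0, reduction="mean"):
    """KL(teacher || student) on temperature-scaled softmaxes, mimicking F.kl_div reductions."""
    log_q = log_softmax(student_logits / t)      # student log-probabilities (the "input")
    log_p = log_softmax(teacher_logits / t)      # teacher log-probabilities (the "target")
    pointwise = np.exp(log_p) * (log_p - log_q)  # p * (log p - log q) per sample and class
    if reduction == "batchmean":
        return pointwise.sum() / student_logits.shape[0]
    if reduction == "mean":
        return pointwise.mean()                  # divides by batch_size * num_classes
    raise ValueError(reduction)

rng = np.random.default_rng(0)
s, teacher = rng.normal(size=(16, 100)), rng.normal(size=(16, 100))
print(kl_div(s, teacher, reduction="batchmean"), kl_div(s, teacher, reduction="mean"))
```

Switching the notebook to `batchmean` would rescale the gradients by the number of classes, so the learning rates used below would likely need retuning.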

In [17]:
def train(train_data,labels,model,criterion,optimizer,epoch,use_cuda,num_batchs=999999,batch_size=32, uniform_reg=False):
    # switch to train mode
    model.train()
    
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    
    len_t = len(train_data)//batch_size
    if len(train_data)%batch_size:
        len_t += 1
    
    
    for ind in range(len_t):
        if ind > num_batchs:
            break
        
        inputs = train_data[ind*batch_size:(ind+1)*batch_size]
        targets = labels[ind*batch_size:(ind+1)*batch_size]

        inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # compute output
        outputs,_,_ = model(inputs)
        
        if uniform_reg==True:
            # uniform target; ones_like matches the actual batch size, which is smaller for the last batch
            uniform_ = torch.ones_like(outputs)
            t_softmax=1
            loss = criterion(outputs, targets)+100*F.kl_div(F.log_softmax(outputs/t_softmax, dim=1), F.softmax(uniform_/t_softmax, dim=1))
        else:
            loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg, top1.avg)
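The `accuracy` helper used in `train` comes from `purchase_util` and is not shown in this notebook. A hypothetical NumPy equivalent, assuming it returns top-k accuracies as percentages (the logs print values between 0 and 100); it counts a sample as correct when its true label is among the k highest-scoring classes:

```python
import numpy as np

def topk_accuracy(logits, targets, topk=(1, 5)):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    logits = np.asarray(logits)
    targets = np.asarray(targets)
    ranked = np.argsort(-logits, axis=1)  # class indices sorted by descending score
    results = []
    for k in topk:
        hit = (ranked[:, :k] == targets[:, None]).any(axis=1)
        results.append(100.0 * hit.mean())
    return results

logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.1, 0.4],
                   [0.2, 0.3, 0.5]])
targets = np.array([1, 2, 0])
print(topk_accuracy(logits, targets, topk=(1, 2)))  # [33.33..., 66.66...]
```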

In [24]:
user_lr=0.0005
at_lr=0.0001
n_classes=100
t_softmax=1.0

In [9]:
checkpoint_dir='./dmp'

criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()


batch_size=128
gamma=0.5
num_epochs=100

model=PurchaseClassifier()
model=model.cuda()

optimizer=optim.Adam(model.parameters(), lr=user_lr, weight_decay=1e-3)

best_val_acc=0
best_test_acc=0

for epoch in range(num_epochs):
#     if epoch in schedule:
#         for param_group in optimizer.param_groups:
#             param_group['lr'] *= gamma
#             print('Epoch %d Local lr %f'%(epoch,param_group['lr']))

    # note: batch_size (128) is not passed here, so train() uses its default batch_size=32
    train_loss, train_acc = train(private_data_tensor, private_label_tensor, model, criterion, optimizer, epoch, use_cuda, uniform_reg=False)

    val_loss, val_acc = test(val_data_tensor, val_label_tensor, model, criterion, use_cuda)

    is_best = val_acc >= best_val_acc

    best_val_acc=max(val_acc, best_val_acc)

    if is_best:
        _, best_test_acc = test(te_data_tensor,te_label_tensor,model,criterion,use_cuda)

    save_checkpoint_global(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_acc': best_val_acc,
            'optimizer': optimizer.state_dict(),
        },
        is_best,
        checkpoint=checkpoint_dir,
        filename='unprotected_model.pth.tar',
        best_filename='unprotected_model_best.pth.tar',
    )

    print('epoch %d | tr acc %.2f loss %.2f | val acc %.2f loss %.2f | best val acc %.2f | best te acc %.2f'
          %(epoch, train_acc, train_loss, val_acc, val_loss, best_val_acc, best_test_acc))



epoch 0 | tr acc 27.58 loss 2.96 | val acc 48.75 loss 2.00 | best val acc 48.75 | best te acc 47.95
epoch 1 | tr acc 53.47 loss 1.67 | val acc 53.70 loss 1.56 | best val acc 53.70 | best te acc 51.95
epoch 2 | tr acc 61.91 loss 1.28 | val acc 54.06 loss 1.42 | best val acc 54.06 | best te acc 52.95
epoch 3 | tr acc 66.88 loss 1.08 | val acc 57.09 loss 1.30 | best val acc 57.09 | best te acc 56.17
epoch 4 | tr acc 70.52 loss 0.93 | val acc 57.74 loss 1.29 | best val acc 57.74 | best te acc 57.38
epoch 5 | tr acc 72.95 loss 0.85 | val acc 60.91 loss 1.18 | best val acc 60.91 | best te acc 59.16
epoch 6 | tr acc 74.91 loss 0.78 | val acc 62.14 loss 1.14 | best val acc 62.14 | best te acc 60.55
epoch 7 | tr acc 77.92 loss 0.70 | val acc 60.91 loss 1.18 | best val acc 62.14 | best te acc 60.55
epoch 8 | tr acc 79.00 loss 0.65 | val acc 65.19 loss 1.03 | best val acc 65.19 | best te acc 64.57
epoch 9 | tr acc 83.08 loss 0.55 | val acc 67.24 loss 0.95 | best val acc 67.24 | best te acc 66.78


epoch 82 | tr acc 100.00 loss 0.06 | val acc 75.84 loss 0.71 | best val acc 76.57 | best te acc 75.64
epoch 83 | tr acc 100.00 loss 0.06 | val acc 76.00 loss 0.70 | best val acc 76.57 | best te acc 75.64
epoch 84 | tr acc 100.00 loss 0.06 | val acc 76.05 loss 0.70 | best val acc 76.57 | best te acc 75.64
epoch 85 | tr acc 100.00 loss 0.06 | val acc 75.96 loss 0.71 | best val acc 76.57 | best te acc 75.64
epoch 86 | tr acc 95.97 loss 0.18 | val acc 57.98 loss 1.45 | best val acc 76.57 | best te acc 75.64
epoch 87 | tr acc 91.51 loss 0.30 | val acc 72.27 loss 0.87 | best val acc 76.57 | best te acc 75.64
epoch 88 | tr acc 98.41 loss 0.11 | val acc 73.98 loss 0.80 | best val acc 76.57 | best te acc 75.64
epoch 89 | tr acc 99.47 loss 0.07 | val acc 74.16 loss 0.78 | best val acc 76.57 | best te acc 75.64
epoch 90 | tr acc 99.81 loss 0.06 | val acc 74.94 loss 0.75 | best val acc 76.57 | best te acc 75.64
epoch 91 | tr acc 99.94 loss 0.06 | val acc 74.96 loss 0.73 | best val acc 76.57 | best
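The training loop selects the checkpoint by validation accuracy and reports test accuracy only for that checkpoint, so the test set never influences model selection. Because it uses `>=`, a later epoch that ties the best validation accuracy replaces the earlier one. A minimal sketch of that rule on made-up per-epoch accuracies; `select_by_validation` is a hypothetical helper, and for simplicity it takes a test accuracy for every epoch, whereas the notebook only evaluates the test set when a new best is found:

```python
def select_by_validation(val_accs, test_accs):
    """Return (best_val, test_at_best) using the loop's `val_acc >= best_val_acc` rule."""
    best_val, test_at_best = float("-inf"), None
    for val, test in zip(val_accs, test_accs):
        if val >= best_val:  # ties go to the later epoch
            best_val, test_at_best = val, test
    return best_val, test_at_best

# made-up accuracies for four epochs
print(select_by_validation([50.0, 62.0, 60.0, 62.0], [49.0, 60.5, 61.0, 61.5]))  # (62.0, 61.5)
```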

In [13]:
# distil the knowledge of the unprotected model in the ref data

best_model=PurchaseClassifier().cuda()

resume_best=checkpoint_dir+'/unprotected_model_best.pth.tar'

assert os.path.isfile(resume_best), 'Error: no checkpoint directory found for best model'
checkpoint = torch.load(resume_best)
best_model.load_state_dict(checkpoint['state_dict'])

_,best_test = test(te_data_tensor, te_label_tensor, best_model, criterion, use_cuda)
_,best_train = test(private_data_tensor, private_label_tensor, best_model, criterion, use_cuda)
print('unprotected model: train acc %.4f test acc %.4f'%(best_train, best_test))

batch_size=100
all_outputs=[]

best_model.eval()
with torch.no_grad():
    # stepping by batch_size also covers a final partial batch
    for start in range(0, len(ref_data_tensor), batch_size):
        inputs = ref_data_tensor[start:start+batch_size]
        if use_cuda:
            inputs = inputs.cuda()
        outputs,_,_ = best_model(inputs)
        all_outputs.append(outputs.cpu().numpy())

# raw logits of the unprotected model on the reference data serve as the distilled (soft) labels
final_outputs=np.concatenate(all_outputs)
distil_label_tensor=torch.from_numpy(final_outputs).type(torch.FloatTensor)




# train final protected model via knowledge distillation

distil_model=PurchaseClassifier().cuda()
distil_test_criterion=nn.CrossEntropyLoss()

distil_schedule=[60, 90, 150]
distil_lr=.1
distil_epochs=200

distil_best_acc=0
best_distil_test_acc=0
gamma=.1
t_softmax=1

for epoch in range(distil_epochs):
    if epoch in distil_schedule:
        distil_lr *= gamma
        print('----> Epoch %d distillation lr %f'%(epoch,distil_lr))

    distil_optimizer=optim.SGD(distil_model.parameters(), lr=distil_lr, momentum=0.99, weight_decay=1e-5)

    distil_tr_loss = train_pub(ref_data_tensor, distil_label_tensor, ref_label_tensor, distil_model, t_softmax,
                               distil_optimizer, batch_size=16, alpha=1)

    tr_loss,tr_acc = test(private_data_tensor, private_label_tensor, distil_model, distil_test_criterion, use_cuda)
    
    val_loss,val_acc = test(val_data_tensor, val_label_tensor, distil_model, distil_test_criterion, use_cuda)

    distil_is_best = val_acc >= distil_best_acc

    distil_best_acc=max(val_acc, distil_best_acc)

    if distil_is_best:
        _,best_distil_test_acc = test(te_data_tensor, te_label_tensor, distil_model, distil_test_criterion, use_cuda)

    save_checkpoint_global(
        {
            'epoch': epoch,
            'state_dict': distil_model.state_dict(),
            'best_acc': distil_best_acc,
            'optimizer': distil_optimizer.state_dict(),
        },
        distil_is_best,
        checkpoint=checkpoint_dir,
        filename='protected_model.pth.tar',
        best_filename='protected_model_best.pth.tar',
    )

    print('epoch %d | distil loss %.4f | tr loss %.4f tr acc %.4f | val loss %.4f val acc %.4f | best val acc %.4f | best test acc %.4f'%(epoch,distil_tr_loss,tr_loss,tr_acc,val_loss,val_acc,distil_best_acc,best_distil_test_acc))

unprotected model: train acc 99.4692 test acc 75.6431
epoch 0 | distil loss 0.0330 | tr loss 2.6630 tr acc 23.0669 | val loss 2.6896 val acc 23.0105 | best val acc 23.0105 | best test acc 23.4124
epoch 1 | distil loss 0.0137 | tr loss 1.3279 tr acc 57.1414 | val loss 1.4066 val acc 54.5418 | best val acc 54.5418 | best test acc 53.3963
epoch 2 | distil loss 0.0078 | tr loss 1.2328 tr acc 58.2632 | val loss 1.3431 val acc 54.3810 | best val acc 54.5418 | best test acc 53.3963
epoch 3 | distil loss 0.0061 | tr loss 1.1856 tr acc 59.2849 | val loss 1.3257 val acc 54.9839 | best val acc 54.9839 | best test acc 54.4815
epoch 4 | distil loss 0.0055 | tr loss 1.1137 tr acc 62.0493 | val loss 1.2460 val acc 57.8376 | best val acc 57.8376 | best test acc 56.8931
epoch 5 | distil loss 0.0053 | tr loss 1.0195 tr acc 65.3846 | val loss 1.1799 val acc 58.8023 | best val acc 58.8023 | best test acc 58.8023
epoch 6 | distil loss 0.0044 | tr loss 0.9677 tr acc 67.3778 | val loss 1.1326 val acc 61.4349

epoch 57 | distil loss 0.0013 | tr loss 0.5861 tr acc 82.0012 | val loss 0.8575 val acc 70.4180 | best val acc 70.5788 | best test acc 71.5233
epoch 58 | distil loss 0.0014 | tr loss 0.6256 tr acc 80.0881 | val loss 0.8925 val acc 68.7902 | best val acc 70.5788 | best test acc 71.5233
epoch 59 | distil loss 0.0016 | tr loss 0.6249 tr acc 80.5789 | val loss 0.8846 val acc 69.2122 | best val acc 70.5788 | best test acc 71.5233
----> Epoch 60 distillation lr 0.010000
epoch 60 | distil loss 0.0010 | tr loss 0.5090 tr acc 85.3365 | val loss 0.7792 val acc 73.5330 | best val acc 73.5330 | best test acc 73.1310
epoch 61 | distil loss 0.0008 | tr loss 0.5016 tr acc 85.5869 | val loss 0.7726 val acc 73.7339 | best val acc 73.7339 | best test acc 73.3923
epoch 62 | distil loss 0.0007 | tr loss 0.4974 tr acc 85.8273 | val loss 0.7689 val acc 73.8545 | best val acc 73.8545 | best test acc 73.3923
epoch 63 | distil loss 0.0007 | tr loss 0.4941 tr acc 85.9575 | val loss 0.7662 val acc 73.9349 | best

epoch 114 | distil loss 0.0005 | tr loss 0.4612 tr acc 87.5000 | val loss 0.7478 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 115 | distil loss 0.0005 | tr loss 0.4611 tr acc 87.5000 | val loss 0.7478 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 116 | distil loss 0.0005 | tr loss 0.4611 tr acc 87.5100 | val loss 0.7478 val acc 74.4976 | best val acc 74.6182 | best test acc 73.9550
epoch 117 | distil loss 0.0005 | tr loss 0.4610 tr acc 87.5200 | val loss 0.7477 val acc 74.4976 | best val acc 74.6182 | best test acc 73.9550
epoch 118 | distil loss 0.0005 | tr loss 0.4610 tr acc 87.5200 | val loss 0.7477 val acc 74.4976 | best val acc 74.6182 | best test acc 73.9550
epoch 119 | distil loss 0.0005 | tr loss 0.4609 tr acc 87.5200 | val loss 0.7477 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 120 | distil loss 0.0005 | tr loss 0.4609 tr acc 87.5100 | val loss 0.7477 val acc 74.5177 | best val acc 74.6182 | best test acc 

epoch 171 | distil loss 0.0005 | tr loss 0.4599 tr acc 87.4900 | val loss 0.7472 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 172 | distil loss 0.0005 | tr loss 0.4599 tr acc 87.4900 | val loss 0.7472 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 173 | distil loss 0.0005 | tr loss 0.4599 tr acc 87.5100 | val loss 0.7472 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 174 | distil loss 0.0005 | tr loss 0.4599 tr acc 87.5100 | val loss 0.7472 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 175 | distil loss 0.0005 | tr loss 0.4599 tr acc 87.5100 | val loss 0.7472 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 176 | distil loss 0.0005 | tr loss 0.4599 tr acc 87.5100 | val loss 0.7472 val acc 74.5177 | best val acc 74.6182 | best test acc 73.9550
epoch 177 | distil loss 0.0005 | tr loss 0.4599 tr acc 87.5100 | val loss 0.7472 val acc 74.5177 | best val acc 74.6182 | best test acc 
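The distillation cell decays the learning rate by `gamma=0.1` at epochs 60, 90 and 150 (the log above shows `lr 0.010000` from epoch 60). Because it rebuilds `optim.SGD` every epoch to apply the new rate, the SGD momentum buffers (`momentum=0.99`) also restart from zero each epoch. A small sketch of the schedule's arithmetic; `step_lr` is a hypothetical helper:

```python
def step_lr(epoch, base_lr=0.1, milestones=(60, 90, 150), gamma=0.1):
    """Learning rate in effect at `epoch`: multiply by gamma once per milestone reached."""
    n_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** n_decays

for e in (0, 59, 60, 89, 90, 149, 150, 199):
    print(e, step_lr(e))
```

In PyTorch, keeping a single optimizer and calling `scheduler.step()` once per epoch on `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90, 150], gamma=0.1)` should produce the same learning rates while preserving the momentum state; note that this would change the training dynamics relative to the logs shown here.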

In [14]:
# train final protected model via knowledge distillation

distil_model=PurchaseClassifier().cuda()
distil_test_criterion=nn.CrossEntropyLoss()

distil_schedule=[60, 90, 150]
distil_lr=.1
distil_epochs=200

distil_best_acc=0
best_distil_test_acc=0
gamma=.1
t_softmax=1

for epoch in range(distil_epochs):
    if epoch in distil_schedule:
        distil_lr *= gamma
        print('----> Epoch %d distillation lr %f'%(epoch,distil_lr))

    distil_optimizer=optim.SGD(distil_model.parameters(), lr=distil_lr, momentum=0.99, weight_decay=1e-5)

    distil_tr_loss = train_pub(ref_data_tensor, distil_label_tensor, ref_label_tensor, distil_model, t_softmax,
                               distil_optimizer, batch_size=32, alpha=1)

    tr_loss,tr_acc = test(private_data_tensor, private_label_tensor, distil_model, distil_test_criterion, use_cuda)
    
    val_loss,val_acc = test(val_data_tensor, val_label_tensor, distil_model, distil_test_criterion, use_cuda)

    distil_is_best = val_acc >= distil_best_acc

    distil_best_acc=max(val_acc, distil_best_acc)

    if distil_is_best:
        _,best_distil_test_acc = test(te_data_tensor, te_label_tensor, distil_model, distil_test_criterion, use_cuda)

    save_checkpoint_global(
        {
            'epoch': epoch,
            'state_dict': distil_model.state_dict(),
            'best_acc': distil_best_acc,
            'optimizer': distil_optimizer.state_dict(),
        },
        distil_is_best,
        checkpoint=checkpoint_dir,
        filename='protected_model.pth.tar',
        best_filename='protected_model_best.pth.tar',
    )

    print('epoch %d | distil loss %.4f | tr loss %.4f tr acc %.4f | val loss %.4f val acc %.4f | best val acc %.4f | best test acc %.4f'%(epoch,distil_tr_loss,tr_loss,tr_acc,val_loss,val_acc,distil_best_acc,best_distil_test_acc))

epoch 0 | distil loss 0.0383 | tr loss 4.2935 tr acc 3.5757 | val loss 4.3025 val acc 3.2958 | best val acc 3.2958 | best test acc 3.0949
epoch 1 | distil loss 0.0289 | tr loss 2.6946 tr acc 24.7596 | val loss 2.7269 val acc 23.8947 | best val acc 23.8947 | best test acc 23.6937
epoch 2 | distil loss 0.0171 | tr loss 1.8594 tr acc 44.6915 | val loss 1.9134 val acc 42.8457 | best val acc 42.8457 | best test acc 42.5844
epoch 3 | distil loss 0.0103 | tr loss 1.4389 tr acc 55.1082 | val loss 1.5171 val acc 52.1302 | best val acc 52.1302 | best test acc 51.9494
epoch 4 | distil loss 0.0075 | tr loss 1.1874 tr acc 62.7304 | val loss 1.2945 val acc 58.1793 | best val acc 58.1793 | best test acc 57.6166
epoch 5 | distil loss 0.0063 | tr loss 1.1429 tr acc 62.6903 | val loss 1.2655 val acc 58.2998 | best val acc 58.2998 | best test acc 57.2749
epoch 6 | distil loss 0.0057 | tr loss 1.0464 tr acc 66.0056 | val loss 1.1792 val acc 60.9325 | best val acc 60.9325 | best test acc 60.1889
epoch 7 | 

epoch 58 | distil loss 0.0021 | tr loss 0.6231 tr acc 80.5389 | val loss 0.8757 val acc 69.9960 | best val acc 71.1616 | best test acc 70.4180
epoch 59 | distil loss 0.0016 | tr loss 0.6362 tr acc 80.0581 | val loss 0.9010 val acc 70.0563 | best val acc 71.1616 | best test acc 70.4180
----> Epoch 60 distillation lr 0.010000
epoch 60 | distil loss 0.0011 | tr loss 0.5249 tr acc 84.7456 | val loss 0.7910 val acc 73.5932 | best val acc 73.5932 | best test acc 72.5080
epoch 61 | distil loss 0.0008 | tr loss 0.5218 tr acc 84.8958 | val loss 0.7888 val acc 73.4727 | best val acc 73.5932 | best test acc 72.5080
epoch 62 | distil loss 0.0008 | tr loss 0.5197 tr acc 84.8558 | val loss 0.7869 val acc 73.4325 | best val acc 73.5932 | best test acc 72.5080
epoch 63 | distil loss 0.0008 | tr loss 0.5179 tr acc 84.9259 | val loss 0.7853 val acc 73.5531 | best val acc 73.5932 | best test acc 72.5080
epoch 64 | distil loss 0.0008 | tr loss 0.5163 tr acc 85.0260 | val loss 0.7840 val acc 73.5732 | best

epoch 115 | distil loss 0.0006 | tr loss 0.4919 tr acc 86.3281 | val loss 0.7700 val acc 73.9751 | best val acc 74.0555 | best test acc 73.2114
epoch 116 | distil loss 0.0006 | tr loss 0.4918 tr acc 86.3381 | val loss 0.7700 val acc 73.9550 | best val acc 74.0555 | best test acc 73.2114
epoch 117 | distil loss 0.0006 | tr loss 0.4918 tr acc 86.3381 | val loss 0.7700 val acc 73.9550 | best val acc 74.0555 | best test acc 73.2114
epoch 118 | distil loss 0.0006 | tr loss 0.4918 tr acc 86.3281 | val loss 0.7700 val acc 73.9751 | best val acc 74.0555 | best test acc 73.2114
epoch 119 | distil loss 0.0006 | tr loss 0.4917 tr acc 86.3281 | val loss 0.7700 val acc 73.9952 | best val acc 74.0555 | best test acc 73.2114
epoch 120 | distil loss 0.0006 | tr loss 0.4917 tr acc 86.3281 | val loss 0.7700 val acc 73.9952 | best val acc 74.0555 | best test acc 73.2114
epoch 121 | distil loss 0.0006 | tr loss 0.4916 tr acc 86.3281 | val loss 0.7700 val acc 74.0354 | best val acc 74.0555 | best test acc 

epoch 172 | distil loss 0.0006 | tr loss 0.4908 tr acc 86.3782 | val loss 0.7699 val acc 73.8344 | best val acc 74.0555 | best test acc 73.1913
epoch 173 | distil loss 0.0006 | tr loss 0.4908 tr acc 86.3782 | val loss 0.7699 val acc 73.8344 | best val acc 74.0555 | best test acc 73.1913
epoch 174 | distil loss 0.0006 | tr loss 0.4908 tr acc 86.3782 | val loss 0.7699 val acc 73.8344 | best val acc 74.0555 | best test acc 73.1913
epoch 175 | distil loss 0.0006 | tr loss 0.4908 tr acc 86.3782 | val loss 0.7699 val acc 73.8344 | best val acc 74.0555 | best test acc 73.1913
epoch 176 | distil loss 0.0006 | tr loss 0.4907 tr acc 86.3782 | val loss 0.7699 val acc 73.8344 | best val acc 74.0555 | best test acc 73.1913
epoch 177 | distil loss 0.0006 | tr loss 0.4907 tr acc 86.3782 | val loss 0.7699 val acc 73.8344 | best val acc 74.0555 | best test acc 73.1913
epoch 178 | distil loss 0.0006 | tr loss 0.4907 tr acc 86.3782 | val loss 0.7699 val acc 73.8344 | best val acc 74.0555 | best test acc 

In [20]:
# distil the knowledge of the unprotected model in the ref data
checkpoint_dir='./dmp'
criterion=nn.CrossEntropyLoss()

best_model=PurchaseClassifier().cuda()

resume_best=checkpoint_dir+'/unprotected_model_best.pth.tar'

assert os.path.isfile(resume_best), 'Error: no checkpoint directory found for best model'
checkpoint = torch.load(resume_best)
best_model.load_state_dict(checkpoint['state_dict'])

_,best_test = test(te_data_tensor, te_label_tensor, best_model, criterion, use_cuda)
_,best_train = test(private_data_tensor, private_label_tensor, best_model, criterion, use_cuda)
print('unprotected model: train acc %.4f test acc %.4f'%(best_train, best_test))

batch_size=100
all_outputs=[]

best_model.eval()
with torch.no_grad():
    # stepping by batch_size also covers a final partial batch
    for start in range(0, len(ref_data_tensor), batch_size):
        inputs = ref_data_tensor[start:start+batch_size]
        if use_cuda:
            inputs = inputs.cuda()
        outputs,_,_ = best_model(inputs)
        all_outputs.append(outputs.cpu().numpy())

# raw logits of the unprotected model on the reference data serve as the distilled (soft) labels
final_outputs=np.concatenate(all_outputs)
distil_label_tensor=torch.from_numpy(final_outputs).type(torch.FloatTensor)



# train final protected model via knowledge distillation

distil_model=PurchaseClassifier().cuda()
distil_test_criterion=nn.CrossEntropyLoss()

distil_schedule=[60, 90, 150]
distil_lr=.1
distil_epochs=200

distil_best_acc=0
best_distil_test_acc=0
gamma=.1
t_softmax=1

for epoch in range(distil_epochs):
    if epoch in distil_schedule:
        distil_lr *= gamma
        print('----> Epoch %d distillation lr %f'%(epoch,distil_lr))

    distil_optimizer=optim.SGD(distil_model.parameters(), lr=distil_lr, momentum=0.99, weight_decay=1e-5)

    distil_tr_loss = train_pub(ref_data_tensor, distil_label_tensor, ref_label_tensor, distil_model, t_softmax,
                               distil_optimizer, batch_size=32, alpha=1)

    tr_loss,tr_acc = test(private_data_tensor, private_label_tensor, distil_model, distil_test_criterion, use_cuda)
    
    val_loss,val_acc = test(val_data_tensor, val_label_tensor, distil_model, distil_test_criterion, use_cuda)

    distil_is_best = val_acc >= distil_best_acc

    distil_best_acc=max(val_acc, distil_best_acc)

    if distil_is_best:
        _,best_distil_test_acc = test(te_data_tensor, te_label_tensor, distil_model, distil_test_criterion, use_cuda)

    save_checkpoint_global(
        {
            'epoch': epoch,
            'state_dict': distil_model.state_dict(),
            'best_acc': distil_best_acc,
            'optimizer': distil_optimizer.state_dict(),
        },
        distil_is_best,
        checkpoint=checkpoint_dir,
        filename='protected_model.pth.tar',
        best_filename='protected_model_best.pth.tar',
    )

    print('epoch %d | distil loss %.4f | tr loss %.4f tr acc %.4f | val loss %.4f val acc %.4f | best val acc %.4f | best test acc %.4f'%(epoch,distil_tr_loss,tr_loss,tr_acc,val_loss,val_acc,distil_best_acc,best_distil_test_acc))

unprotected model: train acc 99.4692 test acc 75.6431
epoch 0 | distil loss 0.0383 | tr loss 4.3424 tr acc 3.1951 | val loss 4.3512 val acc 3.0547 | best val acc 3.0547 | best test acc 2.8135
epoch 1 | distil loss 0.0305 | tr loss 2.8379 tr acc 23.4575 | val loss 2.8673 val acc 22.1865 | best val acc 22.1865 | best test acc 22.5884
epoch 2 | distil loss 0.0174 | tr loss 1.7843 tr acc 48.2973 | val loss 1.8450 val acc 46.6037 | best val acc 46.6037 | best test acc 46.2822
epoch 3 | distil loss 0.0100 | tr loss 1.3768 tr acc 58.1130 | val loss 1.4635 val acc 54.8834 | best val acc 54.8834 | best test acc 54.2805
epoch 4 | distil loss 0.0072 | tr loss 1.2242 tr acc 61.3181 | val loss 1.3333 val acc 57.2749 | best val acc 57.2749 | best test acc 56.8127
epoch 5 | distil loss 0.0061 | tr loss 1.0986 tr acc 64.5833 | val loss 1.2267 val acc 59.1439 | best val acc 59.1439 | best test acc 58.4003
epoch 6 | distil loss 0.0055 | tr loss 1.0429 tr acc 65.5148 | val loss 1.1813 val acc 60.2894 | ...

epoch 58 | distil loss 0.0015 | tr loss 0.6099 tr acc 80.6891 | val loss 0.8821 val acc 70.1969 | best val acc 71.0611 | best test acc 71.0611
epoch 59 | distil loss 0.0016 | tr loss 0.6243 tr acc 80.1082 | val loss 0.8899 val acc 70.1166 | best val acc 71.0611 | best test acc 71.0611
----> Epoch 60 distillation lr 0.010000
epoch 60 | distil loss 0.0011 | tr loss 0.5279 tr acc 84.5853 | val loss 0.7915 val acc 73.1913 | best val acc 73.1913 | best test acc 72.4879
epoch 61 | distil loss 0.0008 | tr loss 0.5202 tr acc 84.8858 | val loss 0.7852 val acc 73.3521 | best val acc 73.3521 | best test acc 72.6286
epoch 62 | distil loss 0.0008 | tr loss 0.5181 tr acc 84.9860 | val loss 0.7837 val acc 73.3320 | best val acc 73.3521 | best test acc 72.6286
epoch 63 | distil loss 0.0008 | tr loss 0.5162 tr acc 85.1162 | val loss 0.7822 val acc 73.2315 | best val acc 73.3521 | best test acc 72.6286
epoch 64 | distil loss 0.0008 | tr loss 0.5146 tr acc 85.2163 | val loss 0.7812 val acc 73.2315 | ...

epoch 115 | distil loss 0.0006 | tr loss 0.4921 tr acc 85.9976 | val loss 0.7682 val acc 73.5732 | best val acc 73.5732 | best test acc 73.2114
epoch 116 | distil loss 0.0006 | tr loss 0.4921 tr acc 86.0076 | val loss 0.7682 val acc 73.5732 | best val acc 73.5732 | best test acc 73.2114
epoch 117 | distil loss 0.0006 | tr loss 0.4920 tr acc 86.0076 | val loss 0.7682 val acc 73.5732 | best val acc 73.5732 | best test acc 73.2114
epoch 118 | distil loss 0.0006 | tr loss 0.4920 tr acc 86.0076 | val loss 0.7682 val acc 73.5732 | best val acc 73.5732 | best test acc 73.2114
epoch 119 | distil loss 0.0006 | tr loss 0.4920 tr acc 86.0076 | val loss 0.7682 val acc 73.5330 | best val acc 73.5732 | best test acc 73.2114
epoch 120 | distil loss 0.0006 | tr loss 0.4919 tr acc 85.9976 | val loss 0.7682 val acc 73.5330 | best val acc 73.5732 | best test acc 73.2114
epoch 121 | distil loss 0.0006 | tr loss 0.4919 tr acc 86.0176 | val loss 0.7681 val acc 73.5330 | best val acc 73.5732 | ...

epoch 172 | distil loss 0.0006 | tr loss 0.4914 tr acc 86.0276 | val loss 0.7680 val acc 73.4727 | best val acc 73.5732 | best test acc 73.2114
epoch 173 | distil loss 0.0006 | tr loss 0.4914 tr acc 86.0276 | val loss 0.7680 val acc 73.4727 | best val acc 73.5732 | best test acc 73.2114
epoch 174 | distil loss 0.0006 | tr loss 0.4914 tr acc 86.0276 | val loss 0.7680 val acc 73.4727 | best val acc 73.5732 | best test acc 73.2114
epoch 175 | distil loss 0.0006 | tr loss 0.4914 tr acc 86.0276 | val loss 0.7680 val acc 73.4727 | best val acc 73.5732 | best test acc 73.2114
epoch 176 | distil loss 0.0006 | tr loss 0.4914 tr acc 86.0276 | val loss 0.7680 val acc 73.4727 | best val acc 73.5732 | best test acc 73.2114
epoch 177 | distil loss 0.0006 | tr loss 0.4914 tr acc 86.0276 | val loss 0.7680 val acc 73.4727 | best val acc 73.5732 | best test acc 73.2114
epoch 178 | distil loss 0.0006 | tr loss 0.4913 tr acc 86.0276 | val loss 0.7680 val acc 73.4727 | best val acc 73.5732 | ...
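The training targets passed to `train_pub` above come from the unprotected model, but `train_pub` itself lives in `purchase_private_train` and its body is not shown in this notebook. Assuming it follows the usual soft-label distillation recipe (minimise the KL divergence between the teacher's soft labels and the student's softmax at temperature `t_softmax`), the core of the loss looks roughly like the numpy sketch below. The names `softmax` and `distillation_kl` are hypothetical, and how `alpha` and `ref_label_tensor` enter the real loss is not visible here.

```python
import numpy as np

# ASSUMPTION: a generic soft-label distillation loss, not the code of train_pub.

def softmax(z, T=1.0):
    """Row-wise softmax of logits z at temperature T (numerically stabilised)."""
    z = np.asarray(z, dtype=float) / T
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def distillation_kl(student_logits, teacher_probs, T=1.0):
    """Mean KL(teacher || student) over the batch.

    teacher_probs: soft labels from the unprotected (teacher) model, assumed to
    play the role of distil_label_tensor.
    student_logits: raw outputs of the model being trained (distil_model).
    """
    p = np.asarray(teacher_probs, dtype=float)
    q = softmax(student_logits, T)
    eps = 1e-12  # avoids log(0) for zero-probability classes
    kl = np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=1)
    return kl.mean()

teacher_logits = np.array([[2.0, 0.5, -1.0], [0.0, 1.0, 0.0]])
soft_labels = softmax(teacher_logits)
print(distillation_kl(teacher_logits, soft_labels))   # student == teacher
print(distillation_kl(np.zeros((2, 3)), soft_labels))  # uniform student
```

With `t_softmax=1`, as in the cell above, the temperature leaves the softmax unchanged; the loss is zero exactly when the student reproduces the teacher's soft labels and positive otherwise.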

In [21]:
class InferenceAttack_BB(nn.Module):
    """Black-box membership inference attack model.

    Takes the target model's prediction vector, the one-hot true label and the
    per-example loss, embeds each with its own small MLP, and combines the three
    64-d embeddings into a single score in (0, 1).
    """
    def __init__(self, num_classes):
        super(InferenceAttack_BB, self).__init__()
        self.num_classes = num_classes

        # branch for the target model's output vector (one entry per class)
        self.features = nn.Sequential(
            nn.Linear(num_classes, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
        )

        # branch for the one-hot encoded true label
        self.labels = nn.Sequential(
            nn.Linear(num_classes, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

        # branch for the scalar per-example loss
        self.loss = nn.Sequential(
            nn.Linear(1, num_classes),
            nn.ReLU(),
            nn.Linear(num_classes, 64),
            nn.ReLU(),
        )

        # fusion head: three 64-d embeddings -> one logit
        self.combine = nn.Sequential(
            nn.Linear(64 * 3, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # initialise every linear layer with N(0, 0.01^2) weights and zero biases
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                nn.init.zeros_(m.bias)

        self.output = nn.Sigmoid()

    def forward(self, x1, one_hot_labels, loss):
        out_x1 = self.features(x1)           # (batch, 64)
        out_l = self.labels(one_hot_labels)  # (batch, 64)
        out_loss = self.loss(loss)           # (batch, 64)

        is_member = self.combine(torch.cat((out_x1, out_l, out_loss), 1))
        return self.output(is_member)
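To make the shapes in `InferenceAttack_BB` concrete, here is a numpy mirror of its layer layout (hypothetical helper names, not part of the repository). It counts the parameters for `num_classes=100` and runs one forward pass with the same initialisation (N(0, 0.01^2) weights, zero biases): each of the three branches maps its input to a 64-d embedding, and the fusion head maps the concatenated 192-d vector to a single sigmoid score.

```python
import numpy as np

# Hypothetical helpers: a numpy mirror of the InferenceAttack_BB layout.
rng = np.random.default_rng(0)

def branch_sizes(num_classes):
    """Layer widths of the four sub-networks of InferenceAttack_BB."""
    return {
        "features": [num_classes, 1024, 512, 64],   # target model's prediction vector
        "labels": [num_classes, 128, 64],           # one-hot true label
        "loss": [1, num_classes, 64],               # scalar per-example loss
        "combine": [64 * 3, 512, 256, 128, 64, 1],  # fusion head -> one logit
    }

def mlp_param_count(sizes):
    """Weights plus biases of a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

def init_mlp(sizes):
    # N(0, 0.01^2) weights and zero biases, as in the notebook's initialisation
    return [(rng.normal(0.0, 0.01, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def run_mlp(layers, x, relu_last=True):
    for i, (W, b) in enumerate(layers):
        x = x @ W + b
        if relu_last or i < len(layers) - 1:
            x = np.maximum(x, 0.0)  # ReLU
    return x

def attack_forward(params, preds, one_hot, loss):
    # each branch ends in a ReLU and yields a (batch, 64) embedding
    h = np.concatenate([run_mlp(params["features"], preds),
                        run_mlp(params["labels"], one_hot),
                        run_mlp(params["loss"], loss)], axis=1)  # (batch, 192)
    logit = run_mlp(params["combine"], h, relu_last=False)     # (batch, 1)
    return 1.0 / (1.0 + np.exp(-logit))                        # sigmoid

sizes = branch_sizes(100)
print(sum(mlp_param_count(s) for s in sizes.values()))  # 960265

params = {name: init_mlp(s) for name, s in sizes.items()}
batch = 4
scores = attack_forward(params,
                        np.full((batch, 100), 0.01),  # uniform prediction vectors
                        np.eye(100)[:batch],          # one-hot labels 0..3
                        np.ones((batch, 1)))          # per-example losses
print(scores.shape)  # (4, 1)
```

With weights this small, all initial scores sit close to 0.5; training the attack moves them apart for members and non-members.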

In [22]:
def attack_bb(train_data, labels, attack_data, attack_label, model, inference_model, classifier_criterion, classifier_criterion_noreduct, criterion_attack, classifier_optimizer,
              optimizer, epoch, use_cuda, num_batchs=1000, is_train=False, batch_size=64):
    """Run one pass of the black-box membership inference attack.

    train_data/labels are members of the target model's training set and
    attack_data/attack_label are non-members. Each batch holds batch_size//2
    samples from each side. The attack model is updated only when is_train=True;
    the target model is never updated, so classifier_criterion,
    classifier_optimizer, epoch and use_cuda are unused here.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    if is_train:
        inference_model.train()
    else:
        inference_model.eval()

    skip_batch = 0
    batch_size = batch_size // 2  # half members, half non-members per batch

    # number of member batches (the last one may be partial)
    len_t = len(train_data) // batch_size
    if len(train_data) % batch_size:
        len_t += 1
    # number of full non-member batches; non-member batches wrap around if there are fewer
    n_attack_batches = len(attack_data) // batch_size

    for ind in range(skip_batch, len_t):
        if ind >= skip_batch + num_batchs:
            break

        # member batch: indexed by ind directly, so the last partial batch is also used
        tr_input = train_data[ind*batch_size:(ind+1)*batch_size].cuda()
        tr_target = labels[ind*batch_size:(ind+1)*batch_size].cuda()

        # non-member batch: wraps around when attack_data has fewer batches
        te_ind = ind % n_attack_batches
        te_input = attack_data[te_ind*batch_size:(te_ind+1)*batch_size].cuda()
        te_target = attack_label[te_ind*batch_size:(te_ind+1)*batch_size].cuda()

        model_input = torch.cat((tr_input, te_input))
        infer_input = torch.cat((tr_target, te_target)).long()

        # query the target model as a black box: no gradients flow into it
        with torch.no_grad():
            pred_outputs, _, _ = model(model_input)
            # per-example loss of the target model, shape (batch, 1)
            loss_ = classifier_criterion_noreduct(pred_outputs, infer_input).view([-1, 1])

        # one-hot encoding of the true labels
        infer_input_one_hot = torch.zeros_like(pred_outputs).scatter_(1, infer_input.view([-1, 1]), 1)

        member_output = inference_model(pred_outputs, infer_input_one_hot, loss_)

        # membership labels: 1 for members (train_data), 0 for non-members (attack_data)
        is_member_labels = torch.cat((torch.ones(tr_input.size(0), 1),
                                      torch.zeros(te_input.size(0), 1))).cuda()

        loss = criterion_attack(member_output, is_member_labels)

        # attack accuracy with a 0.5 threshold on the sigmoid output
        prec1 = ((member_output > 0.5).float() == is_member_labels).float().mean().item()
        losses.update(loss.item(), model_input.size(0))
        top1.update(prec1, model_input.size(0))

        # update the attack model only during training
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return (losses.avg, top1.avg)
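The feature construction inside `attack_bb` can be mirrored in numpy to show what the attack model sees for each sample: the target model's output vector, the one-hot true label, and the per-example cross-entropy. This sketch assumes the first output of `PurchaseClassifier` is unnormalised logits (which is what `nn.CrossEntropyLoss` expects); the helper names are hypothetical, and it uses 1 for members and 0 for non-members.

```python
import numpy as np

# Hypothetical helpers mirroring the per-batch features built in attack_bb.

def per_example_cross_entropy(logits, labels):
    """-log softmax(logits)[label] per row, like nn.CrossEntropyLoss(reduction='none')."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

def attack_batch(member_logits, member_labels, nonmember_logits, nonmember_labels, num_classes):
    """Attack inputs and membership targets for one half-member, half-non-member batch."""
    logits = np.concatenate([member_logits, nonmember_logits])
    labels = np.concatenate([member_labels, nonmember_labels])
    one_hot = np.eye(num_classes)[labels]                             # (batch, num_classes)
    loss = per_example_cross_entropy(logits, labels).reshape(-1, 1)   # (batch, 1)
    is_member = np.concatenate([np.ones(len(member_labels)),          # 1 = member
                                np.zeros(len(nonmember_labels))]).reshape(-1, 1)
    return logits, one_hot, loss, is_member

def attack_accuracy(scores, is_member):
    """Fraction of samples whose thresholded score (> 0.5) matches the membership label."""
    return np.mean((scores > 0.5) == is_member)

logits, one_hot, loss, is_member = attack_batch(
    np.zeros((2, 5)), np.array([0, 1]), np.zeros((3, 5)), np.array([2, 3, 4]), num_classes=5)
print(one_hot.shape, loss.shape, is_member.ravel())
```

For reference, uniform logits over 100 classes give a per-example loss of log(100), about 4.605.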

In [27]:
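at_lr=0.0005        # Adam learning rate of the attack model
at_schedule=[100]   # epochs at which the attack lr is multiplied by at_gamma
at_gamma=0.1
n_classes=100       # number of classes of the target model
criterion_classifier = nn.CrossEntropyLoss(reduction='none')  # per-example loss of the target model (attack input)
attack_criterion = nn.MSELoss()                               # loss between the attack score and the 0/1 membership label
criterion = nn.CrossEntropyLoss()                             # mean loss, used when measuring classification accuracy

best_at_val_acc=0
best_at_test_acc=0
attack_epochs=200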

print('\n================ Evaluating DMP model against BB meminf attack  ================\n')

resume_best=checkpoint_dir+'/protected_model_best.pth.tar'

attack_model = InferenceAttack_BB(n_classes)
attack_model = attack_model.cuda()
attack_optimizer = optim.Adam(attack_model.parameters(),lr=at_lr)

best_model=PurchaseClassifier()
best_model=best_model.cuda()
# attack_bb never updates the target model, so no classifier optimizer is needed here
# (the checkpoint's 'optimizer' entry holds SGD state from distillation, not Adam state)
best_opt=None

assert os.path.isfile(resume_best), 'Error: no checkpoint file %s found for best model'%resume_best
checkpoint = torch.load(resume_best)
best_model.load_state_dict(checkpoint['state_dict'])

_, check_test_acc = test(te_data_tensor,te_label_tensor,best_model,criterion,use_cuda)
_, check_val_acc = test(val_data_tensor,val_label_tensor,best_model,criterion,use_cuda)
_, check_train_acc = test(private_data_tensor,private_label_tensor,best_model,criterion,use_cuda)

print('Private model | train acc %.4f | val acc %.4f | test acc %.4f'%(check_train_acc,check_val_acc,check_test_acc))

# np.random.shuffle(ref_indices)
# at_tr_data_tensor=torch.from_numpy(ref_data[ref_indices[:target_tr_len]]).type(torch.FloatTensor)
# at_tr_label_tensor=torch.from_numpy(ref_label[ref_indices[:target_tr_len]]).type(torch.LongTensor)

for epoch in range(attack_epochs):
    if epoch in at_schedule:
        for param_group in attack_optimizer.param_groups:
            param_group['lr'] *= at_gamma
            print('Epoch %d Local lr %f'%(epoch,param_group['lr']))

    at_loss, at_acc = attack_bb(mia_train_members_data_tensor, mia_train_members_label_tensor,
                                mia_train_nonmembers_data_tensor, mia_train_nonmembers_label_tensor,
                                best_model, attack_model, criterion, criterion_classifier, attack_criterion, best_opt,
                                attack_optimizer, epoch, use_cuda, is_train=True, batch_size=64)

    at_val_loss, at_val_acc = attack_bb(mia_val_members_data_tensor, mia_val_members_label_tensor,
                                        mia_val_nonmembers_data_tensor, mia_val_nonmembers_label_tensor,
                                        best_model, attack_model, criterion, criterion_classifier, attack_criterion, best_opt,
                                        attack_optimizer, epoch, use_cuda, is_train=False, batch_size=64)

    # report attack test accuracy at the epoch with the best attack validation accuracy
    is_best = at_val_acc > best_at_val_acc

    if is_best:
        at_test_loss, best_at_test_acc = attack_bb(mia_test_members_data_tensor, mia_test_members_label_tensor,
                                                   mia_test_nonmembers_data_tensor, mia_test_nonmembers_label_tensor, 
                                                   best_model, attack_model, criterion, criterion_classifier, attack_criterion, best_opt,
                                                   attack_optimizer, epoch, use_cuda, is_train=False, batch_size=64)


    best_at_val_acc = max(best_at_val_acc, at_val_acc)
    
    print('protected_model | epoch %d | attack val acc %.4f | best val acc %.4f | test acc %.4f'%(epoch, at_val_acc, best_at_val_acc, best_at_test_acc) )
    
    



Private model | train acc 86.0076 | val acc 73.5732 | test acc 73.2114
protected_model | epoch 0 | attack val acc 0.5641 | best val acc 0.5641 | test acc 0.5742
protected_model | epoch 1 | attack val acc 0.5649 | best val acc 0.5649 | test acc 0.5837
protected_model | epoch 2 | attack val acc 0.5690 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 3 | attack val acc 0.5676 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 4 | attack val acc 0.5653 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 5 | attack val acc 0.5680 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 6 | attack val acc 0.5653 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 7 | attack val acc 0.5674 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 8 | attack val acc 0.5676 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 9 | attack val acc 0.5645 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 10 | ...

protected_model | epoch 89 | attack val acc 0.5386 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 90 | attack val acc 0.5380 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 91 | attack val acc 0.5326 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 92 | attack val acc 0.5301 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 93 | attack val acc 0.5362 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 94 | attack val acc 0.5334 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 95 | attack val acc 0.5376 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 96 | attack val acc 0.5352 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 97 | attack val acc 0.5370 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 98 | attack val acc 0.5342 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 99 | attack val acc 0.5318 | best val acc 0.5690 | test acc 0.5916

protected_model | epoch 178 | attack val acc 0.5188 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 179 | attack val acc 0.5206 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 180 | attack val acc 0.5182 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 181 | attack val acc 0.5176 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 182 | attack val acc 0.5216 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 183 | attack val acc 0.5198 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 184 | attack val acc 0.5235 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 185 | attack val acc 0.5227 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 186 | attack val acc 0.5206 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 187 | attack val acc 0.5202 | best val acc 0.5690 | test acc 0.5916
protected_model | epoch 188 | attack val acc 0.5174 | best val acc 0.5690 | ...
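The attack accuracy is easiest to read relative to the 50% of random guessing. When members and non-members are (approximately) balanced, as in the half-and-half batches of `attack_bb`, accuracy equals (TPR + (1 - FPR)) / 2, so the membership advantage TPR - FPR (as defined by Yeom et al.) is 2 * acc - 1. The sketch below applies this, together with the train-test accuracy gap, to numbers printed in the logs above; the helper names are hypothetical and not part of the repository.

```python
def membership_advantage(attack_acc):
    """TPR - FPR of an attack evaluated on a balanced member/non-member set.

    With equal numbers of members and non-members,
    acc = (TPR + (1 - FPR)) / 2, hence TPR - FPR = 2 * acc - 1.
    """
    return 2.0 * attack_acc - 1.0

def generalization_gap(train_acc, test_acc):
    """Train accuracy minus test accuracy, in the same units as the inputs."""
    return train_acc - test_acc

# numbers reported in the logs above
print(round(membership_advantage(0.5916), 4))          # 0.1832 (attack on the protected model)
print(round(generalization_gap(86.0076, 73.2114), 4))  # 12.7962 (protected model)
print(round(generalization_gap(99.4692, 75.6431), 4))  # 23.8261 (unprotected model)
```

The advantage only summarises this one attack at a 0.5 threshold; a different attack or threshold can give a different value.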