In [1]:
from torch.utils.data.dataset import Dataset
import torch
from sklearn.metrics import roc_auc_score
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score
from abc import ABCMeta, abstractmethod
class AbstractAttacker(metaclass=ABCMeta):
    def __init__(self, splitnn):
        """attacker against SplitNN
        Args:
            splitnn: SplitNN
        """
        self.splitnn = splitnn

    def fit(self):
        pass

    @abstractmethod
    def attack(self):
        pass


class NormAttack(AbstractAttacker):
    def __init__(self, splitnn):
        """Class that implement normattack
        Args:
            splitnn (attack_splitnn.splitnn.SplitNN): target splotnn model
        """
        super().__init__(splitnn)
        self.splitnn = splitnn

    def attack(self, dataloader, criterion, device):
        """Culculate leak_auc on the given SplitNN model
           reference: https://arxiv.org/abs/2102.08504
        Args:
            dataloader (torch dataloader): dataloader for evaluation
            criterion: loss function for training
            device: cpu or GPU
        Returns:
            score: culculated leak auc
        """
        epoch_labels = []
        epoch_g_norm = []
        for i, data in enumerate(dataloader, 0):

            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = self.splitnn(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            self.splitnn.backward()

            grad_from_server = self.splitnn.client.grad_from_server
            g_norm = grad_from_server.pow(2).sum(dim=1).sqrt()
            epoch_labels.append(labels)
            epoch_g_norm.append(g_norm)

        epoch_labels = torch.cat(epoch_labels)
        epoch_g_norm = torch.cat(epoch_g_norm)
        score = roc_auc_score(epoch_labels, epoch_g_norm.view(-1, 1))
        return score
class DataSet(Dataset):
    """This class allows you to convert numpy.array to torch.Dataset
    Args:
        x (np.array):
        y (np.array):
        transform (torch.transform):
    Attriutes
        x (np.array):
        y (np.array):
        transform (torch.transform):
    """

    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.transform = transform

    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """get the number of rows of self.x
        """
        return len(self.x)


def torch_roc_auc_score(label, pred):
    return roc_auc_score(label.cpu().detach().numpy(),
                         pred.cpu().detach().numpy())

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from ucimlrepo import fetch_ucirepo 
  
# fetch dataset 
spambase = fetch_ucirepo(id=94) 
  
# data (as pandas dataframes) 
X = spambase.data.features 
y = spambase.data.targets 
  
import pandas as pd
raw_df = pd.concat([X, y], axis=1)
raw_df = raw_df.rename(columns={raw_df.columns[-1]: 'label'})
scaler = preprocessing.StandardScaler()
raw_df.iloc[:,:-1] = pd.DataFrame(scaler.fit_transform(raw_df.iloc[:,:-1]), columns = raw_df.iloc[:,:-1].columns)
raw_df_neg = raw_df[raw_df["label"] == 0]
raw_df_pos = raw_df[raw_df["label"] == 1]
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    raw_df.shape[0],raw_df_pos.shape[0], 100 * raw_df_pos.shape[0] / raw_df.shape[0]))


Examples:
    Total: 4601
    Positive: 1813 (39.40% of total)



In [5]:
config = {
    "batch_size":1028
}
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
train_df, test_df = train_test_split(raw_df, test_size=0.3)
# train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df['label'])
bool_train_labels = train_labels != 0
# val_labels = np.array(val_df)
test_labels = np.array(test_df['label'])

train_features = np.array(train_df.drop(['label'],axis=1))
# val_features = np.array(val_df)
test_features = np.array(test_df.drop(['label'],axis=1))
print('Training labels shape:', train_labels.shape)
# print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
# print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
train_dataset = DataSet(train_features,
                        train_labels.astype(np.float64).reshape(-1, 1))
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=config["batch_size"],
                                           shuffle=True)

test_dataset = DataSet(test_features,
                       test_labels.astype(np.float64).reshape(-1, 1))
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=config["batch_size"],
                                          shuffle=True)

cuda:0
Training labels shape: (3220,)
Test labels shape: (1381,)
Training features shape: (3220, 57)
Test features shape: (1381, 57)


In [7]:
class FirstNet(nn.Module):
    def __init__(self,hidden_dim = 10):
        super(FirstNet, self).__init__()        
        self.L1 = nn.Linear(train_features.shape[-1],
                            hidden_dim)
        self.L2 = nn.Linear(hidden_dim,
                            1)

    def forward(self, x):
        x = self.L1(x)
        x = nn.functional.leaky_relu(x)
        x = self.L2(x)
        x = torch.sigmoid(x)
        # x = nn.functional.leaky_relu(x)
        return x
    
class SecondNet(nn.Module):
    def __init__(self,hidden_dim = 10):
        super(SecondNet, self).__init__()        
        self.L1 = nn.Linear(1,
                            hidden_dim)
        self.L2 = nn.Linear(hidden_dim,
                            1)

    def forward(self, x):
        x = self.L1(x)
        x = nn.functional.leaky_relu(x)
        x = self.L2(x)
        # x = nn.functional.leaky_relu(x)
        x = torch.sigmoid(x)
        return x

    
def torch_auc(label, pred):
    return roc_auc_score(label.cpu() .detach().numpy(),
                         pred.cpu() .detach().numpy())

In [8]:
#SplitNN
import torch
class Client(torch.nn.Module):
    def __init__(self, client_model):
        super().__init__()
        """class that expresses the Client on SplitNN
        Args:
            client_model (torch model): client-side model
        Attributes:
            client_model (torch model): cliet-side model
            client_side_intermidiate (torch.Tensor): output of
                                                     client-side model
            grad_from_server
        """

        self.client_model = client_model
        self.client_side_intermidiate = None
        self.grad_from_server = None

    def forward(self, inputs):
        """client-side feed forward network
        Args:
            inputs (torch.Tensor): the input data
        Returns:
            intermidiate_to_server (torch.Tensor): the output of client-side
                                                   model which the client sent
                                                   to the server
        """

        self.client_side_intermidiate = self.client_model(inputs)
        # send intermidiate tensor to the server
        intermidiate_to_server = self.client_side_intermidiate.detach()\
            .requires_grad_()

        return intermidiate_to_server

    def client_backward(self, grad_from_server):
        """client-side back propagation
        Args:
            grad_from_server: gradient which the server send to the client
        """
        self.grad_from_server = grad_from_server
        self.client_side_intermidiate.backward(grad_from_server)

    def train(self):
        self.client_model.train()

    def eval(self):
        self.client_model.eval()


class Server(torch.nn.Module):
    def __init__(self, server_model):
        super().__init__()
        """class that expresses the Server on SplitNN
        Args:
            server_model (torch model): server-side model
        Attributes:
            server_model (torch model): server-side model
            intermidiate_to_server:
            grad_to_client
        """
        self.server_model = server_model

        self.intermidiate_to_server = None
        self.grad_to_client = None

    def forward(self, intermidiate_to_server):
        """server-side training
        Args:
            intermidiate_to_server (torch.Tensor): the output of client-side
                                                   model
        Returns:
            outputs (torch.Tensor): outputs of server-side model
        """
        self.intermidiate_to_server = intermidiate_to_server
        outputs = self.server_model(intermidiate_to_server)

        return outputs

    def server_backward(self):
        self.grad_to_client = self.intermidiate_to_server.grad.clone()
        return self.grad_to_client

    def train(self):
        self.server_model.train()

    def eval(self):
        self.server_model.eval()


class SplitNN(torch.nn.Module):
    def __init__(self, client, server,
                 client_optimizer, server_optimizer
                 ):
        super().__init__()
        """class that expresses the whole architecture of SplitNN
        Args:
            client (attack_splitnn.splitnn.Client):
            server (attack_splitnn.splitnn.Server):
            clietn_optimizer
            server_optimizer
        Attributes:
            client (attack_splitnn.splitnn.Client):
            server (attack_splitnn.splitnn.Server):
            clietn_optimizer
            server_optimizer
        """
        self.client = client
        self.server = server
        self.client_optimizer = client_optimizer
        self.server_optimizer = server_optimizer
        self.grad_to_client =None

        self.intermidiate_to_server = None

    def forward(self, inputs,labels):
        # execute client - feed forward network
        self.labels=labels
        self.intermidiate_to_server = self.client(inputs)
        # execute server - feed forward netwoek
        outputs = self.server(self.intermidiate_to_server)
        # grad_to_client = self.server.server_backward(self.intermidiate_to_server)
        # grad_to_client = self.server.server_backward()

        return outputs,self.intermidiate_to_server

    def backward(self):
        # execute server - back propagation
        self.grad_to_client = self.server.server_backward()
        # execute client - back propagation
        # if model=='Marvell':
        #   grad_to_client=KL_gradient_perturb_function_creator(self.label,grad_to_client)
        
        self.client.client_backward(self.grad_to_client)

    def zero_grads(self):
        self.client_optimizer.zero_grad()
        self.server_optimizer.zero_grad()

    def step(self):
        self.client_optimizer.step()
        self.server_optimizer.step()

    def train(self):
        self.client.train()
        self.server.train()

    def eval(self):
        self.client.eval()
        self.server.eval()

In [10]:
def plot_auc_and_leak(train_auc, test_auc, na_leak_auc,ma_leak_auc,median_leak_auc):
    fig, ax = plt.subplots(1,2)
    ax[0].plot(train_auc, marker='', color='skyblue', linewidth=2,label="Training AUC")
    ax[0].plot(test_auc, marker='', color='olive', linewidth=2,label="Testing AUC")
    ax[0].legend()
    ax[0].set_title("AUC")
    ax[1].set_title("Leak AUC")
    ax[1].plot(na_leak_auc, marker='', color='skyblue', linewidth=2, label="Norm Leak AUC")
    ax[1].plot(ma_leak_auc, marker='', color='olive', linewidth=2, label="Mean Leak AUC")
    ax[1].plot(median_leak_auc, marker='', color='yellow', linewidth=2, label="Median Leak AUC")
    ax[1].legend()
    plt.show()
def plot_pre_labels(prediction,labels):
    data = pd.DataFrame(columns=['label', 'plot_gradient','y_hat'])
    data['True Label'] = labels.reshape(1,-1)[0].detach().numpy()
    data['plot_gradient'] = labels.detach().numpy()
    data['prediction'] = prediction.detach().numpy()
    plt.cla()
    sns.color_palette('Set1')
    #sns.histplot(data, x='plot_gradient',y='y_hat', hue='label')
    sns.jointplot(data=data, y='True Label', x='prediction', hue='True Label')
    plt.xlabel("")
    plt.show()

In [11]:
#solve_isotropic_covariance
import math
import random
from collections import Counter
import numpy

OBJECTIVE_EPSILON = 1e-16
CONVEX_EPSILON = 1e-20
NUM_CANDIDATE = 1


def symKL_objective(lam10, lam20, lam11, lam21, u, v, d, g):
    if (lam21 + v) == 0.0 or (lam20 + u) == 0.0 or (lam11 + v) == 0.0 or (lam10 + u) == 0.0:
        return float('inf')
    objective = (d - 1) * (lam20 + u) / (lam21 + v) \
                + (d - 1) * (lam21 + v) / (lam20 + u) \
                + (lam10 + u + g) / (lam11 + v) \
                + (lam11 + v + g) / (lam10 + u)
    return objective


def symKL_objective_zero_uv(lam10, lam11, g):
    objective = (lam10 + g) / lam11 \
                + (lam11 + g) / lam10
    return objective


def solve_isotropic_covariance(u, v, d, g, p, P,
                               lam10_init=None, lam20_init=None,
                               lam11_init=None, lam21_init=None):
    """ return the solution to the optimization problem
        Args:
        u ([type]): [the coordinate variance of the negative examples]
        v ([type]): [the coordinate variance of the positive examples]
        d ([type]): [the dimension of activation to protect]
        g ([type]): [squared 2-norm of g_0 - g_1, i.e. \|g^{(0)} - g^{(1)}\|_2^2]
        P ([type]): [the power constraint value]
    """

    if u == 0.0 and v == 0.0:
        return solve_zero_uv(g=g, p=p, P=P)

    ordering = [0, 1, 2]
    random.shuffle(x=ordering)

    solutions = []
    if u <= v:
        for i in range(NUM_CANDIDATE):
            if i % 3 == ordering[0]:
                # print('a')
                if lam20_init:  # if we pass an initialization
                    lam20 = lam20_init
                    # print('here')
                else:
                    lam20 = random.random() * P / (1 - p) / d
                lam10, lam11 = None, None
                # print('lam21', lam21)
            elif i % 3 == ordering[1]:
                # print('b')
                if lam11_init:
                    lam11 = lam11_init
                else:
                    lam11 = random.random() * P / p
                lam10, lam20 = None, None
                # print('lam11', lam11)
            else:
                # print('c')
                if lam10_init:
                    lam10 = lam10_init
                else:
                    lam10 = random.random() * P / (1 - p)
                lam11, lam20 = None, None
                # print('lam10', lam10)

            solutions.append(solve_small_neg(u=u, v=v, d=d, g=g, p=p, P=P, lam10=lam10, lam11=lam11, lam20=lam20))

    else:
        for i in range(NUM_CANDIDATE):
            if i % 3 == ordering[0]:
                if lam21_init:
                    lam21 = lam21_init
                else:
                    lam21 = random.random() * P / p / d
                lam10, lam11 = None, None
                # print('lam21', lam21)
            elif i % 3 == ordering[1]:
                if lam11_init:
                    lam11 = lam11_init
                else:
                    lam11 = random.random() * P / p
                lam10, lam21 = None, None
                # print('lam11', lam11)
            else:
                if lam10_init:
                    lam10 = lam10_init
                else:
                    lam10 = random.random() * P / (1 - p)
                lam11, lam21 = None, None
                # print('lam10', lam10)

            solutions.append(solve_small_pos(u=u, v=v, d=d, g=g, p=p, P=P, lam10=lam10, lam11=lam11, lam21=lam21))

    # print(solutions)
    lam10, lam20, lam11, lam21, objective = min(solutions, key=lambda x: x[-1])

    # print('sum', p * lam11 + p*(d-1)*lam21 + (1-p) * lam10 + (1-p)*(d-1)*lam20)

    return (lam10, lam20, lam11, lam21, objective)


def solve_zero_uv(g, p, P):
    C = P

    E = math.sqrt((C + (1 - p) * g) / (C + p * g))
    tau = max((P / (p)) / (E + (1 - p) /  (p)), 0.0)
    # print('tau', tau)
    if 0 <= tau and tau <= P / (1 - p):
        # print('A')
        lam10 = tau
        lam11 = max(P /  (p) - (1 - p) * tau /  (p), 0.0)
    else:
        # print('B')
        lam10_case1, lam11_case1 = 0.0, max(P /  (p), 0.0)
        lam10_case2, lam11_case2 = max(P / (1 - p), 0), 0.0
        objective1 = symKL_objective_zero_uv(lam10=lam10_case1, lam11=lam11_case1,
                                             g=g)
        objective2 = symKL_objective_zero_uv(lam10=lam10_case2, lam11=lam11_case2,
                                             g=g)
        if objective1 < objective2:
            lam10, lam11 = lam10_case1, lam11_case1
        else:
            lam10, lam11 = lam10_case2, lam11_case2

    objective = symKL_objective_zero_uv(lam10=lam10, lam11=lam11, g=g)
    # here we subtract d = 1 because the distribution is essentially one-dimensional
    return (lam10, 0.0, lam11, 0.0, 0.5 * objective - 1)


def solve_small_neg(u, v, d, g, p, P, lam10=None, lam20=None, lam11=None):
    """[When u < v]
    """
    # some intialization to start the alternating optimization
    LAM21 = 0.0
    i = 0
    objective_value_list = []

    if lam20:
        ordering = [0, 1, 2]
    elif lam11:
        ordering = [1, 0, 2]
    else:
        ordering = [1, 2, 0]
    # print(ordering)

    while True:
        if i % 3 == ordering[0]:  # fix lam20
            D = P - (1 - p) * (d - 1) * lam20
            C = D + p * v + (1 - p) * u

            E = math.sqrt((C + (1 - p) * g) / (C + p * g))
            tau = max((D / p + v - E * u) / (E + (1 - p) / p), 0.0)
            # print('tau', tau)
            if lam20 <= tau and tau <= P / (1 - p) - (d - 1) * lam20:
                # print('A')
                lam10 = tau
                lam11 = max(D / p - (1 - p) * tau / p, 0.0)
            else:
                # print('B')
                lam10_case1, lam11_case1 = lam20, max(P / p - (1 - p) * d * lam20 / p, 0.0)
                lam10_case2, lam11_case2 = max(P / (1 - p) - (d - 1) * lam20, 0), 0.0
                objective1 = symKL_objective(lam10=lam10_case1, lam20=lam20, lam11=lam11_case1, lam21=LAM21,
                                             u=u, v=v, d=d, g=g)
                objective2 = symKL_objective(lam10=lam10_case2, lam20=lam20, lam11=lam11_case2, lam21=LAM21,
                                             u=u, v=v, d=d, g=g)
                if objective1 < objective2:
                    lam10, lam11 = lam10_case1, lam11_case1
                else:
                    lam10, lam11 = lam10_case2, lam11_case2

        elif i % 3 == ordering[1]:  # fix lam11
            D = max((P - p * lam11) / (1 - p), 0.0)
            f = lambda x: symKL_objective(lam10=D - (d - 1) * x, lam20=x, lam11=lam11, lam21=LAM21,
                                          u=u, v=v, d=d, g=g)

            # f_prime = lambda x: (d-1)/v - (d-1)/(lam11+v) - (d-1)*v/((x+u)**2) + (lam11 + v + g)*(d-1)/((D-(d-1)*x+u)**2) # not numerically stable
            # f_prime = lambda x: (d-1)/v - (d-1)/(lam11+v) - (d-1)/(x+u)*(v/(x+u)) + (lam11 + v + g)/(D-(d-1)*x+u) * ((d-1)/(D-(d-1)*x+u))

            def f_prime(x):
                if x == 0.0 and u == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / v - (d - 1) / (lam11 + v) - (d - 1) / (x + u) * (v / (x + u)) + (lam11 + v + g) / (
                                D - (d - 1) * x + u) * ((d - 1) / (D - (d - 1) * x + u))

            # print('D/d', D/d)
            lam20 = convex_min_1d(xl=0.0, xr=D / d, f=f, f_prime=f_prime)
            lam10 = max(D - (d - 1) * lam20, 0.0)

        else:  # fix lam10
            D = max(P - (1 - p) * lam10, 0.0)  # avoid negative due to numerical error
            f = lambda x: symKL_objective(lam10=lam10, lam20=x, lam11=D / p - (1 - p) * (d - 1) * x / p, lam21=LAM21,
                                          u=u, v=v, d=d, g=g)

            # f_prime = lambda x: (d-1)/v - (1-p)*(d-1)/(lam10 + u)/p - (d-1)*v/((x+u)**2) + (lam10+u+g)*(1-p)*(d-1)/p/((D/p - (1-p)*(d-1)*x/p + v)**2) # not numerically stable
            # f_prime = lambda x: (d-1)/v - (1-p)*(d-1)/(lam10 + u)/p - (d-1)/(x+u)*(v/(x+u)) + (lam10+u+g)/(D/p - (1-p)*(d-1)*x/p + v) * (1-p) * (d-1) / p / (D/p - (1-p)*(d-1)*x/p + v)

            def f_prime(x):
                if x == 0.0 and u == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / v - (1 - p) * (d - 1) / (lam10 + u) / p - (d - 1) / (x + u) * (v / (x + u)) + (
                                lam10 + u + g) / (D / p - (1 - p) * (d - 1) * x / p + v) * (1 - p) * (d - 1) / p / (
                                       D / p - (1 - p) * (d - 1) * x / p + v)

            # print('lam10', 'D/((1-p)*(d-1)', lam10, D/((1-p)*(d-1)))
            lam20 = convex_min_1d(xl=0.0, xr=min(D / ((1 - p) * (d - 1)), lam10), f=f, f_prime=f_prime)
            lam11 = max(D / p - (1 - p) * (d - 1) * lam20 / p, 0.0)

        if lam10 < 0 or lam20 < 0 or lam11 < 0 or LAM21 < 0:  # check to make sure no negative values
            assert False, i

        objective_value_list.append(symKL_objective(lam10=lam10, lam20=lam20, lam11=lam11, lam21=LAM21,
                                                    u=u, v=v, d=d, g=g))
        # print(i)
        # print(objective_value_list[-1])
        # print(lam10, lam20, lam11, LAM21, objective_value_list[-1])
        # print('sum', p * lam11 + p*(d-1)*LAM21 + (1-p) * lam10 + (1-p)*(d-1)*lam20)

        if (i >= 3 and objective_value_list[-4] - objective_value_list[-1] < OBJECTIVE_EPSILON) or i >= 100:
            # print(i)
            return lam10, lam20, lam11, LAM21, 0.5 * objective_value_list[-1] - d

        i += 1


def solve_small_pos(u, v, d, g, p, P, lam10=None, lam11=None, lam21=None):
    """[When u > v] lam20 = 0.0 and will not change throughout the optimization
    """
    # some intialization to start the alternating optimization
    LAM20 = 0.0
    i = 0
    objective_value_list = []
    if lam21:
        ordering = [0, 1, 2]
    elif lam11:
        ordering = [1, 0, 2]
    else:
        ordering = [1, 2, 0]
    # print(ordering)
    while True:
        if i % 3 == ordering[0]:  # fix lam21
            D = P - p * (d - 1) * lam21
            C = D + p * v + (1 - p) * u

            E = math.sqrt((C + (1 - p) * g) / (C + p * g))
            tau = max((D / p + v - E * u) / (E + (1 - p) / p), 0.0)
            # print('tau', tau)
            if 0.0 <= tau and tau <= (P - p * d * lam21) / (1 - p):
                # print('A')
                lam10 = tau
                lam11 = max(D / (p) - (1 - p) * tau / (p), 0.0)
            else:
                # print('B')
                lam10_case1, lam11_case1 = 0, max(P / p - (d - 1) * lam21, 0.0)
                lam10_case2, lam11_case2 = max((P - p * d * lam21) / (1 - p), 0.0), lam21
                objective1 = symKL_objective(lam10=lam10_case1, lam20=LAM20, lam11=lam11_case1, lam21=lam21,
                                             u=u, v=v, d=d, g=g)
                objective2 = symKL_objective(lam10=lam10_case2, lam20=LAM20, lam11=lam11_case2, lam21=lam21,
                                             u=u, v=v, d=d, g=g)
                if objective1 < objective2:
                    lam10, lam11 = lam10_case1, lam11_case1
                else:
                    lam10, lam11 = lam10_case2, lam11_case2

        elif i % 3 == ordering[1]:  # fix lam11
            D = max(P - p * lam11, 0.0)
            f = lambda x: symKL_objective(lam10=(D - p * (d - 1) * x) / (1 - p), lam20=LAM20, lam11=lam11, lam21=x,
                                          u=u, v=v, d=d, g=g)

            # f_prime = lambda x: (d-1)/u - p*(d-1)/(lam11+v)/(1-p) - (d-1)*u/((x+v)**2) + (lam11 + v + g)*p*(d-1)/(1-p)/(((D - p*(d-1)*x)/(1-p) + u)**2) # not numerically stable
            # print('D', D)
            # print('P', P)
            # print('d', d)
            # print('u', u)
            # print('v', v)
            # print('g', g)
            # print('p', p)
            # print('lam11', lam11)
            # print()

            # f_prime = lambda x: (d-1)/u - p*(d-1)/(lam11+v)/(1-p) - (d-1)/(x+v)*(u/(x+v)) + (lam11 + v + g) / ((D - p*(d-1)*x)/(1-p) + u) * p * (d-1) / (1-p) /((D - p*(d-1)*x)/(1-p) + u)

            def f_prime(x):
                if x == 0.0 and v == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / u - p * (d - 1) / (lam11 + v) / (1 - p) - (d - 1) / (x + v) * (u / (x + v)) + (
                                lam11 + v + g) / ((D - p * (d - 1) * x) / (1 - p) + u) * p * (d - 1) / (1 - p) / (
                                       (D - p * (d - 1) * x) / (1 - p) + u)

            # print('lam11', 'D/p/(d-1)', lam11, D/p/(d-1))
            lam21 = convex_min_1d(xl=0.0, xr=min(D / p / (d - 1), lam11), f=f, f_prime=f_prime)
            lam10 = max((D - p * (d - 1) * lam21) / (1 - p), 0.0)

        else:  # fix lam10
            D = max((P - (1 - p) * lam10) / p, 0.0)
            f = lambda x: symKL_objective(lam10=lam10, lam20=LAM20, lam11=D - (d - 1) * x, lam21=x,
                                          u=u, v=v, d=d, g=g)

            # f_prime = lambda x: (d-1)/u - (d-1)/(lam10+u) - (d-1)*u/((x+v)**2) + (lam10 + u + g)*(d-1)/((D-(d-1)*x+v)**2)

            # print('D', D)
            # print('P', P)
            # print('d', d)
            # print('u', u)
            # print('v', v)
            # print('g', g)
            # print('p', p)
            # print('lam10', lam10)
            # print()

            # f_prime = lambda x: (d-1)/u - (d-1)/(lam10+u) - (d-1)/(x+v)*(u/(x+v)) + (lam10 + u + g)/(D-(d-1)*x+v) * (d-1) / (D-(d-1)*x+v)

            def f_prime(x):
                if x == 0.0 and v == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / u - (d - 1) / (lam10 + u) - (d - 1) / (x + v) * (u / (x + v)) + (lam10 + u + g) / (
                                D - (d - 1) * x + v) * (d - 1) / (D - (d - 1) * x + v)

            # def f_prime(x):
            #     print('x', x)
            #     print('d, u, v, g', d, u, v, g)
            #     print('(d-1)/u', (d-1)/u)
            #     print('(d-1)/(lam10+u)', (d-1)/(lam10+u))
            #     print('(d-1)*u/((x+v)**2)', (d-1)*u/((x+v)**2))
            #     print('(lam10 + u + g)*(d-1)/((D-(d-1)*x+v)**2)', (lam10 + u + g)*(d-1)/((D-(d-1)*x+v)**2))

            #     return (d-1)/u - (d-1)/(lam10+u) - (d-1)*u/((x+v)**2) + (lam10 + u + g)*(d-1)/((D-(d-1)*x+v)**2)
            # print('D/d', D/d)
            lam21 = convex_min_1d(xl=0.0, xr=D / d, f=f, f_prime=f_prime)
            lam11 = max(D - (d - 1) * lam21, 0.0)

        if lam10 < 0 or LAM20 < 0 or lam11 < 0 or lam21 < 0:
            assert False, i

        objective_value_list.append(symKL_objective(lam10=lam10, lam20=LAM20, lam11=lam11, lam21=lam21,
                                                    u=u, v=v, d=d, g=g))
        # print(i)
        # print(objective_value_list[-1])
        # print(lam10, LAM20, lam11, lam21)
        # print('sum', p * lam11 + p*(d-1)*lam21 + (1-p) * lam10 + (1-p)*(d-1)*LAM20)

        if (i >= 3 and objective_value_list[-4] - objective_value_list[-1] < OBJECTIVE_EPSILON) or i >= 100:
            # print(i)
            return lam10, LAM20, lam11, lam21, 0.5 * objective_value_list[-1] - d

        i += 1


def convex_min_1d(xl, xr, f, f_prime):
    # print('xl, xr', xl, xr)
    assert xr <= 1e5
    assert xl <= xr, (xl, xr)
    # print('xl, xr', xl, xr)

    xm = (xl + xr) / 2
    # print('xl', xl, f(xl), f_prime(xl))
    # print('xr', xr, f(xr), f_prime(xr))
    # print('xm', xm, f(xm), f_prime(xm))
    # print('abs(xl - xr) <= CONVEX_EPSILON',abs(xl - xr) <= CONVEX_EPSILON,abs(xl - xr) , CONVEX_EPSILON)
    if abs(xl - xr) <= CONVEX_EPSILON:
        # print('min((f(x), x) for x in [xl, xm, xr])[1]',min((f(x), x) for x in [xl, xm, xr])[1])
        return min((f(x), x) for x in [xl, xm, xr])[1]
    if f_prime(xl) <= 0 and f_prime(xr) <= 0:
        return xr
    elif f_prime(xl) >= 0 and f_prime(xr) >= 0:
        return xl
    if f_prime(xm) > 0:
        # print('xm', xm, f(xm), f_prime(xm))
        return convex_min_1d(xl=xl, xr=xm, f=f, f_prime=f_prime)
    else:
        # print('xm', xm, f(xm), f_prime(xm))
        return convex_min_1d(xl=xm, xr=xr, f=f, f_prime=f_prime)


def small_neg_problem_string(u, v, d, g, p, P):
    return 'minimize ({2}-1)*(z + {0})/{1} + ({2}-1)*{1}/(z+{0})+(x+{0}+{3})/(y+{1}) + (y+{1}+{3})/(x+{0}) subject to x>=0, y>=0, z>=0, z<=x, {4}*y+(1-{4})*x+(1-{4})*({2}-1)*z={5}'.format(
        u, v, d, g, p, P)


def small_pos_problem_string(u, v, d, g, p, P):
    return 'minimize ({2}-1)*{0}/(z+{1}) + ({2}-1)*(z + {1})/{0} + (x+{0}+{3})/(y+{1}) + (y+{1}+{3})/(x+{0}) subject to x>=0, y>=0, z>=0, z<=y, {4}*y+(1-{4})*x+{4}*({2}-1)*z={5}'.format(
        u, v, d, g, p, P)


def zero_uv_problem_string(g, p, P):
    return 'minimize (x+{0})/y + (y+{0})/x subject to x>=0, y>=0, {1}*y+(1-{1})*x={2}'.format(g, p, P)

def KL_gradient_perturb_function_creator(Y_Train,g,p_frac='pos_frac', dynamic=False, error_prob_lower_bound=None,
                                         sumKL_threshold=None, init_scale=1.0, uv_choice='uv'):
    # print('p_frac', p_frac)
    # print('dynamic', dynamic)
    if dynamic and (error_prob_lower_bound is not None):
        '''
        if using dynamic and error_prob_lower_bound is specified, we use it to 
        determine the sumKL_threshold and overwrite what is stored in it before.
        '''
        sumKL_threshold = (2 - 4 * error_prob_lower_bound) ** 2
        # print('error_prob_lower_bound', error_prob_lower_bound)
        # print('implied sumKL_threshold', sumKL_threshold)
    # elif dynamic:
    #     print('using sumKL_threshold', sumKL_threshold)

    # print('init_scale', init_scale)
    # print('uv_choice', uv_choice)

    y = list(Y_Train.iloc[:,0])
    pos, neg = [], []
    for i in range(len(y)):
        if y[i] == 1:
            pos.append(i)
        else:
            neg.append(i)
    # print('pos', pos)
    pos_g = [g[i] for i in pos]

    pos_g_mean =numpy.mean(pos_g)
    pos_coordinate_var=numpy.var(pos_g)
    neg_g =[g[i] for i in neg]
    neg_g_mean =numpy.mean(neg_g)
    neg_coordinate_var =numpy.var(neg_g)

    avg_pos_coordinate_var = numpy.mean(pos_coordinate_var)
    avg_neg_coordinate_var = numpy.mean(neg_coordinate_var)

    g_diff = pos_g_mean - neg_g_mean
    g_diff_norm = numpy.sqrt(g_diff**2)


    if uv_choice == 'uv':
        u = float(avg_neg_coordinate_var)
        v = float(avg_pos_coordinate_var)
        # if u == 0.0:
        #     print('neg_g')
        #     print(neg_g)
        # if v == 0.0:
        #     print('pos_g')
        #     print(pos_g)

    if uv_choice == 'same':
          u = float(avg_neg_coordinate_var + avg_pos_coordinate_var) / 2.0
          v = float(avg_neg_coordinate_var + avg_pos_coordinate_var) / 2.0
    elif uv_choice == 'zero':
          u, v = 0.0, 0.0

    d = len(Y_Train)
    if p_frac == 'pos_frac':
          p = float(numpy.sum(y) / len(y))  # p is set as the fraction of positive in the batch
    else:
          p = float(p_frac)

    scale = init_scale
    lam10, lam20, lam11, lam21 = None, None, None, None
    while True:
       P = scale * g_diff_norm ** 2
            # print('g_diff_norm ** 2', g_diff_norm ** 2)
            # print('P', P)
            # print('u, v, d, p', u, v, d, p)
       lam10, lam20, lam11, lam21, sumKL = \
                    solve_isotropic_covariance(
                        u=u,
                        v=v,
                        d=d,
                        g=g_diff_norm ** 2,
                        p=p,
                        P=P,
                        lam10_init=lam10,
                        lam20_init=lam20,
                        lam11_init=lam11,
                        lam21_init=lam21)
       if not dynamic or sumKL <= sumKL_threshold:break



    perturbed_g = g
    perturbed_g += numpy.multiply(numpy.random.normal(0,1,len(y)),
                                                      y) * g_diff * (
                                           numpy.sqrt(lam11 - lam21) / g_diff_norm)
    # print(',g,perturbed_g,lam11 , lam21,g_diff',g,perturbed_g,lam11 , lam21,g_diff)

    if lam21 > 0.0:
      perturbed_g += numpy.random.normal(0,1,len(y)) * y * numpy.sqrt(
                        lam21)
      # print('0 perturbed_g lam21',lam21,perturbed_g)

                    # negative examples add noise in g1 - g0
    perturbed_g += numpy.multiply(numpy.random.normal(0,1,len(y)),
                                                      [1-y[i] for i in range(len(y))]) * g_diff * (
                                           numpy.sqrt(lam10 - lam20) / g_diff_norm)
    #print('1 lam21,g_diff_norm,perturbed_g ,',lam21,g_diff_norm,perturbed_g)


                # add spherical noise to negative examples
    if lam20 > 0.0:
          perturbed_g +=numpy.random.normal(0,1,len(y)) *[1-y[i] for i in range(len(y))] * numpy.sqrt(
                        lam20)
          # print('2 perturbed_g',perturbed_g)
    return perturbed_g

In [12]:
#SplitNN
import torch
class Client_marvell(torch.nn.Module):
    def __init__(self, client_model):
        super().__init__()
        """class that expresses the Client on SplitNN
        Args:
            client_model (torch model): client-side model
        Attributes:
            client_model (torch model): cliet-side model
            client_side_intermidiate (torch.Tensor): output of
                                                     client-side model
            grad_from_server
        """

        self.client_model = client_model
        self.client_side_intermidiate = None
        self.grad_from_server = None

    def forward(self, inputs):
        """client-side feed forward network
        Args:
            inputs (torch.Tensor): the input data
        Returns:
            intermidiate_to_server (torch.Tensor): the output of client-side
                                                   model which the client sent
                                                   to the server
        """

        self.client_side_intermidiate = self.client_model(inputs)
        # send intermidiate tensor to the server
        intermidiate_to_server = self.client_side_intermidiate.detach()\
            .requires_grad_()

        return intermidiate_to_server

    def client_backward(self, grad_from_server):
        """client-side back propagation
        Args:
            grad_from_server: gradient which the server send to the client
        """
        self.grad_from_server = grad_from_server
        self.client_side_intermidiate.backward(grad_from_server)

    def train(self):
        self.client_model.train()

    def eval(self):
        self.client_model.eval()


class Server_marvell(torch.nn.Module):
    def __init__(self, server_model):
        super().__init__()
        """class that expresses the Server on SplitNN
        Args:
            server_model (torch model): server-side model
        Attributes:
            server_model (torch model): server-side model
            intermidiate_to_server:
            grad_to_client
        """
        self.server_model = server_model

        self.intermidiate_to_server = None
        self.grad_to_client = None

    def forward(self, intermidiate_to_server):
        """server-side training
        Args:
            intermidiate_to_server (torch.Tensor): the output of client-side
                                                   model
        Returns:
            outputs (torch.Tensor): outputs of server-side model
        """
        self.intermidiate_to_server = intermidiate_to_server
        # print('intermidiate_to_server',intermidiate_to_server)
        outputs = self.server_model(intermidiate_to_server)

        return outputs

    def server_backward(self):
        self.grad_to_client = self.intermidiate_to_server.grad.clone()
        return self.grad_to_client

    def train(self):
        self.server_model.train()

    def eval(self):
        self.server_model.eval()


class SplitNN_marvell(torch.nn.Module):
    def __init__(self, client, server,
                 client_optimizer, server_optimizer,init_scale
                 ):
        super().__init__()
        """class that expresses the whole architecture of SplitNN
        Args:
            client (attack_splitnn.splitnn.Client):
            server (attack_splitnn.splitnn.Server):
            clietn_optimizer
            server_optimizer
        Attributes:
            client (attack_splitnn.splitnn.Client):
            server (attack_splitnn.splitnn.Server):
            clietn_optimizer
            server_optimizer
        """
        self.client = client
        self.server = server
        self.client_optimizer = client_optimizer
        self.server_optimizer = server_optimizer
        self.grad_to_client=None
        self.grad_to_client_1=None
        self.init_scale = init_scale

        self.intermidiate_to_server = None

    def forward(self, inputs,labels):
        # execute client - feed forward network
        self.intermidiate_to_server = self.client(inputs)
        # execute server - feed forward netwoek
        # g_o=self.intermidiate_to_server1.grad.clone()
        outputs = self.server(self.intermidiate_to_server)
        self.labels=pd.DataFrame(labels.cpu().detach().numpy())
    

        return outputs,self.intermidiate_to_server

    def backward(self):
        # execute server - back propagation
        self.grad_to_client_1 = self.server.server_backward()
        # print(' self.grad_to_client_1',self.grad_to_client_1)

       
        self.grad_to_client=KL_gradient_perturb_function_creator(self.labels,
                                                            self.grad_to_client_1.cpu().detach().numpy().T[0],
                                              dynamic=False, error_prob_lower_bound=None,
                                         sumKL_threshold=None, init_scale=self.init_scale, uv_choice='uv')
        # print(' self.grad_to_client',self.grad_to_client)
        # print('grad_to_client',grad_to_client)
        self.grad_to_client=torch.Tensor(self.grad_to_client).reshape(-1,1)
        self.grad_to_client = self.grad_to_client.to(device)
        # execute client - back propagation
        # if model=='Marvell':
        # print('grad_to_client before',grad_to_client.T)
        # print('grad_to_client.detach().numpy().T[0]',grad_to_client.detach().numpy().T[0])
      
        # print('grad_to_client after',grad_to_client.T)
        
        self.client.client_backward(self.grad_to_client)

    def zero_grads(self):
        self.client_optimizer.zero_grad()
        self.server_optimizer.zero_grad()

    def step(self):
        self.client_optimizer.step()
        self.server_optimizer.step()

    def train(self):
        self.client.train()
        self.server.train()

    def eval(self):
        self.client.eval()
        self.server.eval()

In [36]:
def train_marvell(Epochs,lr = 1e-4,init_scale=1,info=True):
  model_1 = FirstNet()
  model_1 = model_1.to(device)

  model_2 = SecondNet()
  model_2 = model_2.to(device)

  model_1.double()
  model_2.double()

  opt_1 = optim.Adam(model_1.parameters(), lr=lr)
  opt_2 = optim.Adam(model_2.parameters(), lr=lr)

  BCE = nn.BCELoss()

  client = Client_marvell(model_1)
  server = Server_marvell(model_2)

  splitnn_marvell = SplitNN_marvell(client, server, opt_1, opt_2,init_scale)
  splitnn_marvell.train()

  for epoch in range(Epochs):
    epoch_loss = 0
    epoch_outputs = []
    epoch_labels = []
    epoch_outputs_test = []
    epoch_labels_test = []
    epoch_g_norm=[]
    epoch_g_mean=[]
    epoch_g_inner=[]
    epoch_g=[]
    for i, data in enumerate(train_loader):
        splitnn_marvell.zero_grads()
        inputs, labels = data
        inputs = inputs.to(device).double()
        labels = labels.to(device).double()
        
        outputs,intermidiate_to_server = splitnn_marvell(inputs,labels)
        loss = BCE(outputs, labels)
        loss.backward()
        splitnn_marvell.backward()
        splitnn_marvell.step()

        epoch_loss += loss.item() / len(train_loader.dataset)
        
        epoch_outputs.append(outputs)
        epoch_labels.append(labels)

        grad_from_server =splitnn_marvell.client.grad_from_server
        g=list(grad_from_server.cpu().detach().numpy())
        g_norm = grad_from_server.pow(2).sum(dim=1).sqrt()
        v_1=np.multiply(grad_from_server.cpu().detach().numpy(),labels.cpu().detach().numpy())
        mean_1=v_1.sum()/len(v_1[v_1!=0])
        mean_0=(grad_from_server.cpu().detach().numpy().sum()-
               v_1.sum())/len(v_1[v_1==0])

        
        g_mean=[]
        for a in g:
          if (a-mean_1)**2<(a-mean_0)**2:g_mean.append([1])
          else:g_mean.append([0])
        g_mean=torch.tensor(g_mean)
        g_inner=[]
        g=list(grad_from_server.cpu().detach().numpy())
        g_inner=[]
        for a in g:
          if a>grad_from_server.median().item():g_inner.append(1)
          else:g_inner.append(0)
        g_inner=torch.tensor(g_inner)
           
        epoch_g_norm.append(g_norm)
        epoch_g_mean.append(g_mean)
        epoch_g_inner.append(g_inner)
        epoch_g.append(grad_from_server)

        t=next(iter(test_loader))
        outputs_test,_ = splitnn_marvell(t[0].to(device),t[1].to(device))
        labels_test=t[1]
        epoch_outputs_test.append(outputs_test)
        epoch_labels_test.append(labels_test)
           
        # print('labels',torch.cat(epoch_g_norm).shape)
        # print('epoch_g_norm',torch.cat(epoch_g_norm).shape)
        # print('epoch_g_norm',torch.cat(epoch_g_norm).shape)

    # print(intermidiate_gradients)
    # print(epoch_outputs)
    # print('epoch_g_norm',torch.cat(epoch_g_norm).shape)
    # print('epoch_g_mean',torch.cat(epoch_g_mean).shape)
    # print('epoch_labels',torch.cat(epoch_labels).shape)

  
   
    train_auc=torch_auc(torch.cat(epoch_labels),
                                torch.cat(epoch_outputs))
    test_auc=torch_auc(torch.cat(epoch_labels_test),
                                torch.cat(epoch_outputs_test))
    train_tvd=0
    na_leak_auc=max(torch_auc(torch.cat(epoch_labels), torch.cat(epoch_g_norm).view(-1, 1)),
                                      1-torch_auc(torch.cat(epoch_labels), 
                                                  torch.cat(epoch_g_norm).view(-1, 1)))
    ma_leak_auc=max(torch_auc(torch.cat(epoch_labels), torch.cat(epoch_g_mean).view(-1, 1)),
                                      1-torch_auc(torch.cat(epoch_labels), 
                                                  torch.cat(epoch_g_mean).view(-1, 1)))
    cos_leak_auc=max(torch_auc(torch.cat(epoch_labels), torch.cat(epoch_g_inner).view(-1, 1)),
                                      1-torch_auc(torch.cat(epoch_labels), 
                                                  torch.cat(epoch_g_inner).view(-1, 1)))
    if info==True and (epoch%10==0 or epoch==Epochs-1):
      print('Epoch',epoch,'Training Loss',epoch_loss, 
          'Training AUC',train_auc,
             'Testing AUC',test_auc,
            "TVD",train_tvd,
            'NA Leak AUC',na_leak_auc,
          'MA Leak AUC',ma_leak_auc,
          'Median Leak AUC',cos_leak_auc
          )
  return train_auc,test_auc,train_tvd,na_leak_auc,ma_leak_auc,cos_leak_auc,splitnn_marvell

# Training Marvell with different seed

# Example 1

In [66]:
train_auc_list_marvell,test_auc_list_marvell,train_tvd_list_marvell,na_leak_auc_list_marvell,ma_leak_auc_list_marvell,cos_leak_auc_list_marvell=[],[],[],[],[],[]
best=1
init_scale = 1
Epoch = 300
random.seed(6)
for i in range(1):
  train_auc_marvell,test_auc_marvell,train_tvd_marvell,na_leak_auc_marvell,ma_leak_auc_marvell,cos_leak_auc_marvell,splitnn_marvell=train_marvell(Epochs=Epoch,init_scale=init_scale,info=True)
  
  # train_auc_marvell,test_auc_marvell,na_leak_auc_marvell,ma_leak_auc_marvell,cos_leak_auc_marvell,splitnn_marvell=train_marvell(Epochs=300,info=True)
  train_auc_list_marvell.append(train_auc_marvell)
  test_auc_list_marvell.append(test_auc_marvell)
  train_tvd_list_marvell.append(train_tvd_marvell)
  na_leak_auc_list_marvell.append(na_leak_auc_marvell)
  ma_leak_auc_list_marvell.append(ma_leak_auc_marvell)
  cos_leak_auc_list_marvell.append(cos_leak_auc_marvell)
  if na_leak_auc_marvell<best:
    best=na_leak_auc_marvell
    marvell_model=splitnn_marvell
print('Mean Training AUC',np.mean(train_auc_list_marvell),np.std(train_auc_list_marvell))
print('Mean Testing AUC',np.mean(test_auc_list_marvell),np.std(test_auc_list_marvell))
print('Mean TVD',np.mean(train_tvd_list_marvell),np.std(train_tvd_list_marvell))
print('Mean NA Leak AUC',np.mean(na_leak_auc_list_marvell),np.std(na_leak_auc_list_marvell))
print('Mean MA Leak AUC',np.mean(ma_leak_auc_list_marvell),np.std(ma_leak_auc_list_marvell))
print('Mean Median Leak AUC',np.mean(cos_leak_auc_list_marvell),np.std(cos_leak_auc_list_marvell))

Epoch 0 Training Loss 0.0027035087819332654 Training AUC 0.593296604429708 Testing AUC 0.6101009092006382 TVD 0 NA Leak AUC 0.5513755142888913 MA Leak AUC 0.6928621250871891 Median Leak AUC 0.6865147640083702
Epoch 10 Training Loss 0.002704580776869042 Training AUC 0.602326253752919 Testing AUC 0.6052194630568635 TVD 0 NA Leak AUC 0.5533249092728688 MA Leak AUC 0.6976052080911416 Median Leak AUC 0.6930248779353638
Epoch 20 Training Loss 0.0027045369930593366 Training AUC 0.6115002577762503 Testing AUC 0.603371771987627 TVD 0 NA Leak AUC 0.5463813268906119 MA Leak AUC 0.7015345268542199 Median Leak AUC 0.6988839804696582
Epoch 30 Training Loss 0.002705956143624192 Training AUC 0.6211218826764251 Testing AUC 0.6348327491784914 TVD 0 NA Leak AUC 0.5494281410794255 MA Leak AUC 0.700418507323878 Median Leak AUC 0.6995349918623576
Epoch 40 Training Loss 0.002702846873044507 Training AUC 0.6299922161681308 Testing AUC 0.6371564695685199 TVD 0 NA Leak AUC 0.5298031802513066 MA Leak AUC 0.67730

In [67]:
print('Mean Training AUC',np.mean(train_auc_list_marvell),np.std(train_auc_list_marvell))
print('Mean Testing AUC',np.mean(test_auc_list_marvell),np.std(test_auc_list_marvell))
print('Mean TVD',np.mean(train_tvd_list_marvell),np.std(train_tvd_list_marvell))
print('Mean NA Leak AUC',np.mean(na_leak_auc_list_marvell),np.std(na_leak_auc_list_marvell))
print('Mean MA Leak AUC',np.mean(ma_leak_auc_list_marvell),np.std(ma_leak_auc_list_marvell))
print('Mean Median Leak AUC',np.mean(cos_leak_auc_list_marvell),np.std(cos_leak_auc_list_marvell))

Mean Training AUC 0.7947019398926438 0.0
Mean Testing AUC 0.8226798592948894 0.0
Mean TVD 0.0 0.0
Mean NA Leak AUC 0.5373476340183779 0.0
Mean MA Leak AUC 0.6882817949314114 0.0
Mean Median Leak AUC 0.6904208323645664 0.0


# Exapmle 2

In [58]:
train_auc_list_marvell,test_auc_list_marvell,train_tvd_list_marvell,na_leak_auc_list_marvell,ma_leak_auc_list_marvell,cos_leak_auc_list_marvell=[],[],[],[],[],[]
best=1
init_scale = 1
Epoch = 300
random.seed(9)
for i in range(1):
  train_auc_marvell,test_auc_marvell,train_tvd_marvell,na_leak_auc_marvell,ma_leak_auc_marvell,cos_leak_auc_marvell,splitnn_marvell=train_marvell(Epochs=Epoch,init_scale=init_scale,info=True)
  
  # train_auc_marvell,test_auc_marvell,na_leak_auc_marvell,ma_leak_auc_marvell,cos_leak_auc_marvell,splitnn_marvell=train_marvell(Epochs=300,info=True)
  train_auc_list_marvell.append(train_auc_marvell)
  test_auc_list_marvell.append(test_auc_marvell)
  train_tvd_list_marvell.append(train_tvd_marvell)
  na_leak_auc_list_marvell.append(na_leak_auc_marvell)
  ma_leak_auc_list_marvell.append(ma_leak_auc_marvell)
  cos_leak_auc_list_marvell.append(cos_leak_auc_marvell)
  if na_leak_auc_marvell<best:
    best=na_leak_auc_marvell
    marvell_model=splitnn_marvell
print('Mean Training AUC',np.mean(train_auc_list_marvell),np.std(train_auc_list_marvell))
print('Mean Testing AUC',np.mean(test_auc_list_marvell),np.std(test_auc_list_marvell))
print('Mean TVD',np.mean(train_tvd_list_marvell),np.std(train_tvd_list_marvell))
print('Mean NA Leak AUC',np.mean(na_leak_auc_list_marvell),np.std(na_leak_auc_list_marvell))
print('Mean MA Leak AUC',np.mean(ma_leak_auc_list_marvell),np.std(ma_leak_auc_list_marvell))
print('Mean Median Leak AUC',np.mean(cos_leak_auc_list_marvell),np.std(cos_leak_auc_list_marvell))

Epoch 0 Training Loss 0.003032412798307199 Training AUC 0.387173255966762 Testing AUC 0.3948984698984699 TVD 0 NA Leak AUC 0.5149267612183214 MA Leak AUC 0.6817949314112998 Median Leak AUC 0.6780516159032783
Epoch 10 Training Loss 0.0030381953120327313 Training AUC 0.39303619987262817 Testing AUC 0.3808996433115641 TVD 0 NA Leak AUC 0.5212927226226458 MA Leak AUC 0.690165077888863 Median Leak AUC 0.6878167867937689
Epoch 20 Training Loss 0.003031600536110635 Training AUC 0.3993853805485075 Testing AUC 0.4058447583460654 TVD 0 NA Leak AUC 0.5071459620108569 MA Leak AUC 0.6968611950709137 Median Leak AUC 0.6865147640083702
Epoch 30 Training Loss 0.00303301541310979 Training AUC 0.40564762492039264 Testing AUC 0.3941555746710129 TVD 0 NA Leak AUC 0.5072616075129142 MA Leak AUC 0.6805394094396652 Median Leak AUC 0.6780516159032783
Epoch 40 Training Loss 0.00302583701018108 Training AUC 0.4120728243179038 Testing AUC 0.40651166249505344 TVD 0 NA Leak AUC 0.5225660362099815 MA Leak AUC 0.707

In [59]:
print('Mean Training AUC',np.mean(train_auc_list_marvell),np.std(train_auc_list_marvell))
print('Mean Testing AUC',np.mean(test_auc_list_marvell),np.std(test_auc_list_marvell))
print('Mean TVD',np.mean(train_tvd_list_marvell),np.std(train_tvd_list_marvell))
print('Mean NA Leak AUC',np.mean(na_leak_auc_list_marvell),np.std(na_leak_auc_list_marvell))
print('Mean MA Leak AUC',np.mean(ma_leak_auc_list_marvell),np.std(ma_leak_auc_list_marvell))
print('Mean Median Leak AUC',np.mean(cos_leak_auc_list_marvell),np.std(cos_leak_auc_list_marvell))

Mean Training AUC 0.5853118081740343 0.0
Mean Testing AUC 0.603217949885456 0.0
Mean TVD 0.0 0.0
Mean NA Leak AUC 0.5096541754698098 0.0
Mean MA Leak AUC 0.6786793768890955 0.0
Mean Median Leak AUC 0.6754475703324808 0.0
