In [1]:
import torch
from torch import nn
from abc import ABC, abstractmethod
from collections import defaultdict

import numpy as np
from torch.nn import init
import torch.nn.functional as F

from models.base import BranchModel
from models.costs import module_cost

import logging

from torchsummary import summary

# Define the path to the saved model

from torch.optim.lr_scheduler import StepLR, MultiStepLR
from tqdm import tqdm

import onnx
import onnx_tf
import tensorflow as tf

# from base.evaluators import standard_eval, branches_eval, binary_eval, \
#     binary_statistics
# from models.base import BranchModel
# from utils import get_device
# from copy import deepcopy

2023-03-12 16:08:39.251610: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-12 16:08:40.132761: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-12 16:08:40.132830: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cpu')
device

device(type='cpu')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class EarlyExitBlock(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, exit_threshold):
        super(EarlyExitBlock, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=1)
        self.exit_threshold = exit_threshold

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        pred = self.softmax(x)
        if torch.max(pred, 1)[0] > self.exit_threshold:
            return pred
        else:
            return None

class EarlyExitNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, exit_thresholds):
        super(EarlyExitNetwork, self).__init__()
        self.exit_thresholds = exit_thresholds
        self.block1 = EarlyExitBlock(input_size, hidden_size, output_size, exit_thresholds[0])
        self.block2 = EarlyExitBlock(hidden_size, hidden_size, output_size, exit_thresholds[1])
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.block1(x)
        if x is not None:
            return x
        x = self.block2(x)
        if x is not None:
            return x
        x = self.fc(x)
        return self.softmax(x)

# Initialize the network and optimizer
input_size = 784
hidden_size = 256
output_size = 10
exit_thresholds = [0.9, 0.8]
net = EarlyExitNetwork(input_size, hidden_size, output_size, exit_thresholds)
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Load the dataset and create data loaders
train_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)

# Train the network
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.view(-1, input_size)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = nn.functional.nll_loss(torch.log(outputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch %d: loss=%.3f' % (epoch+1, running_loss/len(train_loader)))


In [None]:
def trainer(model: BranchModel, 
            predictors: nn.Module,
            optimizer,
            train_loader,
            epochs,
            device,
            scheduler=None,
            early_stopping=None,
            test_loader=None, eval_loader=None):

    scores = []
    mean_losses = []

    best_model = model.state_dict()
    best_model_i = 0
    best_eval_score = -1

    model.to(device)

    if early_stopping is not None:
        early_stopping.reset()

    model.train()
    bar = tqdm(range(epochs), leave=True)
    for epoch in bar:
        model.train()
        losses = []
        for i, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            pred = model(x)[-1]
            pred = predictors[-1].logits(pred)

            loss = nn.functional.cross_entropy(pred, y, reduction='none')
            losses.extend(loss.tolist())
            loss = loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        mean_loss = sum(losses) / len(losses)
        mean_losses.append(mean_loss)
        
        if early_stopping is not None:
            r = early_stopping.step(eval_scores) if eval_loader is not None \
                else early_stopping.step(mean_loss)

            if r < 0:
                break
            elif r > 0:
                best_model = deepcopy(model.state_dict())
                best_predictors = deepcopy(predictors.state_dict())

                best_model_i = epoch
        else:
            if (eval_scores is not None and eval_scores >= best_eval_score) \
                    or eval_scores is None:

                if eval_scores is not None:
                    best_eval_score = eval_scores

                best_model = deepcopy(model.state_dict())
                best_predictors = deepcopy(predictors.state_dict())

                best_model_i = epoch
                train_scores = standard_eval(model=model,
                                     dataset_loader=train_loader,
                                     classifier=predictors[-1])

        test_scores = standard_eval(model=model,
                                    dataset_loader=test_loader,
                                    classifier=predictors[-1])

        bar.set_postfix(
            {'Train score': train_scores, 'Test score': test_scores,
             'Eval score': eval_scores if eval_scores != 0 else 0,
             'Mean loss': mean_loss})

        scores.append((train_scores, eval_scores, test_scores))

## FL


In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img


if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    if args.all_clients: 
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]
    for iter in range(args.epochs):
        loss_locals = []
        if not args.all_clients:
            w_locals = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

    # testing
    net_glob.eval()
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))

In [None]:
ls

## MIA Attack

https://github.com/spring-epfl/mia 
https://github.com/tensorflow/privacy/tree/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack

with tensorflow


In [30]:
# load global cnn model from FL

from models.Nets import MLP, CNNMnist, CNNCifar

import argparse

# Create an ArgumentParser object
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
args = parser.parse_args([])

# set the value of num_classes manually
args.num_classes = 10

model = CNNCifar(args)
model.load_state_dict(torch.load('results/models/cnn'))


<All keys matched successfully>

In [42]:
summary(model, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]             456
         MaxPool2d-2            [-1, 6, 14, 14]               0
            Conv2d-3           [-1, 16, 10, 10]           2,416
         MaxPool2d-4             [-1, 16, 5, 5]               0
            Linear-5                  [-1, 120]          48,120
            Linear-6                   [-1, 84]          10,164
            Linear-7                   [-1, 10]             850
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.31
----------------------------------------------------------------


In [34]:
x = torch.randn(64, 3, 32, 32, requires_grad=True)

# Export the model
torch.onnx.export(model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "cnn.onnx")

In [52]:
prepare(onnx_model).members()

AttributeError: 'TensorflowRep' object has no attribute 'members'

In [43]:
from onnx_tf.backend import prepare

# Load the ONNX model
onnx_model = onnx.load("cnn.onnx")

# Convert the model to TensorFlow format
tf_model = prepare(onnx_model) # Import the ONNX model to Tensorflow

# # Save the model in protobuf format using write_graph()
# tf.io.write_graph(tf_model.graph.as_graph_def(), 'results/models', 'tf_cnn.pb', as_text=False)

# # Save the model in saved_model format using saved_model.save()
# tf.saved_model.save(tf_model, 'results/models/tf_cnn')



In [54]:
pip install tensorflow_federated

Collecting tensorflow_federated
  Downloading tensorflow_federated-0.48.0-py2.py3-none-any.whl (42.8 MB)
[K     |████████████████████████████████| 42.8 MB 19.1 MB/s eta 0:00:01
[?25hCollecting jaxlib==0.3.14
  Downloading jaxlib-0.3.14-cp39-none-manylinux2014_x86_64.whl (71.3 MB)
[K     |████████████████████████████████| 71.3 MB 101.0 MB/s eta 0:00:01
[?25hCollecting grpcio~=1.46
  Downloading grpcio-1.51.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)
[K     |████████████████████████████████| 4.8 MB 104.8 MB/s eta 0:00:01     |███████████████▏                | 2.3 MB 104.8 MB/s eta 0:00:01
[?25hCollecting jax==0.3.14
  Downloading jax-0.3.14.tar.gz (990 kB)
[K     |████████████████████████████████| 990 kB 90.7 MB/s eta 0:00:01
Collecting semantic-version~=2.6
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)
Collecting pytype==2022.12.15
  Downloading pytype-2022.12.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
[K    

Collecting protobuf<3.20,>=3.9.2
  Downloading protobuf-3.19.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 95.3 MB/s eta 0:00:01
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... [?25ldone
[?25h  Created wheel for jax: filename=jax-0.3.14-py3-none-any.whl size=1147584 sha256=38faacf5f841538e9980579793e00e01d7b8a0b9cff4bc7532b54412c230beec
  Stored in directory: /home/yashuo/.cache/pip/wheels/32/21/2b/29f2d0dba28673825c67ce8451e44b07ca7bbf8e68964a82db
Successfully built jax
Installing collected packages: cachetools, typing-extensions, protobuf, grpcio, typing-inspect, flatbuffers, tabulate, pydot, ninja, libcst, jinja2, importlab, tensorflow-privacy, tensorflow-model-optimization, tensorflow-compression, semantic-version, pytype, portpicker, jaxlib, jax, farmhashpy, tensorflow-federated
  Attempting uninstall: cachetools
    Found existing installation: cachetools 4.2.2
    Uninst

    Uninstalling grpcio-1.42.0:
      Successfully uninstalled grpcio-1.42.0
  Attempting uninstall: flatbuffers
    Found existing installation: flatbuffers 23.1.21
    Uninstalling flatbuffers-23.1.21:
      Successfully uninstalled flatbuffers-23.1.21
  Attempting uninstall: tabulate
    Found existing installation: tabulate 0.8.9
    Uninstalling tabulate-0.8.9:
      Successfully uninstalled tabulate-0.8.9
  Attempting uninstall: jinja2
    Found existing installation: Jinja2 2.11.3
    Uninstalling Jinja2-2.11.3:
      Successfully uninstalled Jinja2-2.11.3
  Attempting uninstall: tensorflow-privacy
    Found existing installation: tensorflow-privacy 0.8.7
    Uninstalling tensorflow-privacy-0.8.7:
      Successfully uninstalled tensorflow-privacy-0.8.7
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
anaconda-project 0.10.2 requires ruamel-yaml, whic

In [55]:
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow.keras.datasets import cifar10

# Load and prepare the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.batch(32).shuffle(10000)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_data = test_data.batch(32)

# Define the TensorFlow model
def create_keras_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

# Set up the TFF simulation environment
def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(keras_model, input_spec=train_data.element_spec, loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(model_fn)

# Train the model using the TFF simulation
state = trainer.initialize()
for i in range(10):
    state, metrics = trainer.next(state, [train_data] * 10)
    print('Round {}: loss={}, accuracy={}'.format(i, metrics.loss, metrics.sparse_categorical_accuracy))


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


2023-03-02 12:29:44.561205: W tensorflow/tsl/framework/cpu_allocator_impl.cc:82] Allocation of 153600000 exceeds 10% of free system memory.
2023-03-02 12:29:44.666274: W tensorflow/tsl/framework/cpu_allocator_impl.cc:82] Allocation of 153600000 exceeds 10% of free system memory.


AttributeError: module 'tensorflow_federated.python.learning' has no attribute 'build_federated_averaging_process'

In [4]:
import numpy as np
from typing import Tuple
from scipy import special
from sklearn import metrics
import os
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

# Set verbosity.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from sklearn.exceptions import ConvergenceWarning

import warnings
warnings.simplefilter(action="ignore", category=ConvergenceWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
import tensorflow_privacy

import argparse
# Create an ArgumentParser object
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
args = parser.parse_args([])


# set the value of num_classes manually
args.num_classes = 10

In [6]:
dataset = 'cifar10'
num_classes = 10
activation = 'relu'
num_conv = 3

batch_size=50
epochs_per_report = 2
total_epochs = 50

lr = 0.001

train_ds = tfds.as_numpy(
    tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))
test_ds = tfds.as_numpy(
    tfds.load(dataset, split=tfds.Split.TEST, batch_size=-1))
x_train = train_ds['image'].astype('float32') / 255.
y_train_indices = train_ds['label'][:, np.newaxis]
x_test = test_ds['image'].astype('float32') / 255.
y_test_indices = test_ds['label'][:, np.newaxis]

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)
y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)

print('x_train', np.shape(x_train))
print('y_train', np.shape(y_train))

input_shape = x_train.shape[1:]

2023-03-07 19:42:52.184416: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


x_train (50000, 32, 32, 3)
y_train (50000, 10)


## test branchy net 

In [3]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow_datasets as tfds
import numpy as np
from typing import Tuple
from scipy import special
from sklearn import metrics
import os
import matplotlib.pyplot as plt
# Set verbosity.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from sklearn.exceptions import ConvergenceWarning

import warnings
warnings.simplefilter(action="ignore", category=ConvergenceWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
import tensorflow_privacy

import argparse

In [3]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow_datasets as tfds
import numpy as np
from typing import Tuple
from scipy import special
from sklearn import metrics
import os
import matplotlib.pyplot as plt
# Set verbosity.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from sklearn.exceptions import ConvergenceWarning

import warnings
warnings.simplefilter(action="ignore", category=ConvergenceWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
import tensorflow_privacy

import argparse

class BranchyAlexNet(tf.keras.Model):
    def __init__(self, num_classes):
        super(BranchyAlexNet, self).__init__()
        self.num_classes = num_classes
        
        self.conv1 = Conv2D(96, (11,11), strides=(4,4), activation='relu', padding='valid')
        self.pool1 = MaxPooling2D((3,3), strides=(2,2))
        self.conv2 = Conv2D(256, (5,5), strides=(1,1), activation='relu', padding='same')
        self.pool2 = MaxPooling2D((3,3), strides=(2,2))
        self.conv3 = Conv2D(384, (3,3), strides=(1,1), activation='relu', padding='same')
        self.conv4 = Conv2D(384, (3,3), strides=(1,1), activation='relu', padding='same')
        self.conv5 = Conv2D(256, (3,3), strides=(1,1), activation='relu', padding='same')
        self.pool3 = MaxPooling2D((3,3), strides=(2,2))
        self.flatten = Flatten()
        self.fc1 = Dense(4096, activation='relu')
        self.fc2 = Dense(4096, activation='relu')
        self.fc3 = Dense(num_classes)
        
        self.branch1_fc = Dense(num_classes, name='branch1_fc')
        self.branch2_fc = Dense(num_classes, name='branch2_fc')

    def call(self, x, training=False):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.pool3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        final_output = self.fc3(x)
        branch1_output = self.branch1_fc(x)
        branch2_output = self.branch2_fc(x)
        
        if training:
            return final_output, branch1_output, branch2_output
        else:
            return final_output


parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
args = parser.parse_args([])

# set the value of num_classes manually
args.num_classes = 10

class branchy_CNNCifar(tf.keras.Model):
    def __init__(self, args):
        super(branchy_CNNCifar, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(6, (5, 5), activation='relu')
        self.pool = tf.keras.layers.MaxPooling2D((2, 2))
        self.conv2 = tf.keras.layers.Conv2D(16, (5, 5), activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(120, activation='relu')
        self.fc2 = tf.keras.layers.Dense(84, activation='relu')
        self.fc3 = tf.keras.layers.Dense(args.num_classes)

        # Define the branches
        self.branches = []

        branch1 = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),
                                        tf.keras.layers.Dense(args.num_classes)])

        self.branches.append(branch1)

    def call(self, x):
        x = tf.cast(x, dtype=tf.float32)  # cast input to float32
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.flatten(x)
        branch = self.branches[0]
        branch1_output = branch(x)
        
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return branch1_output, x

dataset = 'cifar10'
num_classes = 10
activation = 'relu'
num_conv = 3

batch_size=50
epochs_per_report = 1
total_epochs = 5

lr = 0.001

# Load the CIFAR-10 dataset

print('Loading the dataset.')
train_ds = tfds.as_numpy(
    tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))
test_ds = tfds.as_numpy(
    tfds.load(dataset, split=tfds.Split.TEST, batch_size=-1))
x_train = train_ds['image'].astype('float32') / 255.
y_train_indices = train_ds['label'][:, np.newaxis]
x_test = test_ds['image'].astype('float32') / 255.
y_test_indices = test_ds['label'][:, np.newaxis]

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)
y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)

print('x_train', np.shape(x_train))
print('y_train', np.shape(y_train))

input_shape = x_train.shape[1:]

assert x_train.shape[0] % batch_size == 0, "The tensorflow_privacy optimizer doesn't handle partial batches"


# Create the BranchyNet model
model = branchy_CNNCifar(args)

# Define your loss function
def cross_entropy_loss(y_true, branch_output, final_output):

    loss_early = tf.keras.losses.categorical_crossentropy(y_true, branch_output)
    loss_final = tf.keras.losses.categorical_crossentropy(y_true, final_output)
        # Compute the total loss
    total_loss = loss_early*0.4 + loss_final*0.6
    return total_loss

# Define your optimizer
optimizer = tf.keras.optimizers.Adam()

# Define your accuracy metric
metric = tf.keras.metrics.SparseCategoricalAccuracy()

# Define the early exit threshold
threshold = 0.7
# Define the training loop

def train_step(inputs, labels):
    # Initialize the gradients
    with tf.GradientTape() as tape:
        # Forward pass
        branch_output, final_output = model(inputs)

        # Compute the total loss
        total_loss =  cross_entropy_loss(labels, branch_output, final_output)
        # print(total_loss)
        # Compute the gradients
        grads = tape.gradient(total_loss, model.trainable_variables)
        # Update the model
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
        # Convert the one-hot encoded labels to integer labels
        labels = tf.argmax(labels, axis=1)
        # Convert the integer labels to one-hot encoded labels
        labels = tf.one_hot(labels, depth=10)
        branch_output = tf.argmax(branch_output, axis =1)
        final_output = tf.argmax(final_output, axis = 1)

        metric.update_state(labels, final_output)
        return final_output, total_loss
        
all_reports = []
epochs_per_report = 1
# callback = PrivacyMetrics(epochs_per_report, "branchy_cnn", model)
# Train the model
for epoch in range(total_epochs):
    total_loss = 0
    # Shuffle the training data
    permutation = np.random.permutation(len(x_train))
    x_train_shuffled = x_train[permutation]
    y_train_shuffled = y_train[permutation]

    for batch in range(0, len(x_train), batch_size):
        # Get the batch
        x_batch = x_train[batch:batch+batch_size]
        y_batch = y_train[batch:batch+batch_size]


        # Reset the accuracy metric
        metric.reset_states()
        # Call the training step function
        outputs, losses = train_step(x_batch, y_batch)
        
        # Compute the total loss
        for loss in losses:
            total_loss += loss
            
        # Print the results

    #### This should be a validation/test for the loss value need midify
    # Collect the required logs in a dictionary
    print("metric.result().numpy()", metric.result().numpy())
    logs = {'loss': total_loss.numpy()/len(x_train), 'val_accuracy': metric.result().numpy()}
    # Call the on_epoch_end method with the logs dictionary
    # callback.on_epoch_end(epoch, logs=logs)
    print("Epoch:", epoch, "Loss:", total_loss.numpy()/len(x_train), "Accuracy:", metric.result().numpy())

Loading the dataset.
x_train (50000, 32, 32, 3)
y_train (50000, 10)
metric.result().numpy() 0.9
Epoch: 0 Loss: 7.733941875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 1 Loss: 7.644310625 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 2 Loss: 7.6871825 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 3 Loss: 7.72679875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 4 Loss: 7.72679875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 5 Loss: 7.72679875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 6 Loss: 7.72679875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 7 Loss: 7.72679875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 8 Loss: 7.72679875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 9 Loss: 7.72679875 Accuracy: 0.9
metric.result().numpy() 0.9
Epoch: 10 Loss: 7.72679875 Accuracy: 0.9


KeyboardInterrupt: 