# Example Notebook

In this notebook we show how to use the library to implement computionally-efficient DP-SGD with JAX.

# 0. Setup (skip until 1. if you don't need the details)

In [1]:
import argparse
import os
import math
import time
import warnings
import jax

import numpy as np

## 0.1 Enviroment variables

In [2]:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".90"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

jax.clear_caches()

## 0.2 Use GPU or CPU?

In [3]:
USE_GPU = False

## 0.3 Arguments

Here you can change the value of the arguments by changing the default values.

In [4]:
parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
parser.add_argument("--num_steps", default=10, type=int, help="Number of steps")
parser.add_argument("--logical_bs", default=100, type=int, help="Logical batch size")
parser.add_argument("--clipping_norm", default=1, type=float, help="max grad norm")

parser.add_argument("--target_epsilon", default=8, type=float, help="target epsilon")
parser.add_argument("--target_delta", default=1e-5, type=float, help="target delta")

parser.add_argument("--physical_bs", default=2, type=int, help="Physical Batch Size")
parser.add_argument("--accountant", default="pld", type=str, help="The privacy accountant for DP training.")

parser.add_argument("--seed", default=1234, type=int)
args = parser.parse_args(args=[])


print("Used args:", args, flush=True)


Used args: Namespace(lr=0.001, num_steps=10, logical_bs=100, clipping_norm=1, target_epsilon=8, target_delta=1e-05, physical_bs=2, accountant='pld', seed=1234)


# 1. Setting up dataset, model and DP accounting
We show you how to setup the dataset, model and how the DP accounting works.

## 1.1 Dataset

We load the dataset from [Hugging Face](https://huggingface.co/) but the only important thing is to have the data available as arrays. Hugging Face supports this nicely but there might be other or even better ways to achieve this.

In [5]:
from jaxdpopt.data import load_from_huggingface

train_images, train_labels, test_images, test_labels = load_from_huggingface(
    "uoft-cs/cifar10", cache_dir=None, feature_name="img"
)
ORIG_IMAGE_DIMENSION, RESIZED_IMAGE_DIMENSION = 32, 32
N_CHANNELS = 3
ORIG_IMAGE_SHAPE = (N_CHANNELS, ORIG_IMAGE_DIMENSION, ORIG_IMAGE_DIMENSION)
train_images = train_images[train_labels < 2].transpose(0, 3, 1, 2)
train_labels = train_labels[train_labels < 2]
test_images = test_images[test_labels < 2].transpose(0, 3, 1, 2)
test_labels = test_labels[test_labels < 2]


num_classes = 2
dataset_size = len(train_labels)

  from .autonotebook import tqdm as notebook_tqdm


## 1.2 Model

We create a `flax.training.train_state.TrainState` and load pre-trained weights from [Hugging Face](https://huggingface.co/) using the `create_train_state` function that we provide in `src.models.py`. In this particular example, we load the weights of a simple Conv_Net.

In [6]:
from jaxdpopt.models import create_train_state
from collections import namedtuple

optimizer_config = namedtuple("Config", ["learning_rate"])
optimizer_config.learning_rate = args.lr

state = create_train_state(
    model_name="small",
    num_classes=num_classes,
    image_dimension=RESIZED_IMAGE_DIMENSION,
    optimizer_config=optimizer_config,
)

load model name small


## 1.3 DP accounting

First, we compute the `subsampling_ratio` based on the `dataset_size` and the (expected) `logical_bs`. Then we compute the required DP-SGD `noise_std` based on the `subsampling_ratio` and the `num_steps` for a particular pair of `target_epsilon` and `target_delta` using a privacy accountant. At the moment the Privacy Loss Distributions (PLDs) and RDP accounting from the google [dp_accounting](https://github.com/google/differential-privacy/tree/main/python/dp_accounting) library are supported. 

*Note: You can also use the accounting tooling of other libraries such as the PyTorch based [opacus](https://github.com/pytorch/opacus).*

In [7]:
from jaxdpopt.dp_accounting_utils import calculate_noise
if dataset_size * args.target_delta > 1.0:
    warnings.warn("Your delta might be too high.")

subsampling_ratio = 1 / math.ceil(dataset_size / args.logical_bs)

noise_std = calculate_noise(
        sample_rate=subsampling_ratio,
        target_epsilon=args.target_epsilon,
        target_delta=args.target_delta,
        steps=args.num_steps,
        accountant=args.accountant,
)

# 2. Function to process one physical batch

First we define the function that computes per example gradients (`compute_per_example_gradients_physical_batch`) and clips them (`clip_and_accumulate_physical_batch`). This function can be jit compiled and then used in the full training loop later (see 3.).

In [8]:
from jaxdpopt.jax_mask_efficient import (
    compute_per_example_gradients_physical_batch,
    add_trees,
    clip_physical_batch,
    accumulate_physical_batch,
    CrossEntropyLoss
)

loss_fn = CrossEntropyLoss(state=state,num_classes=num_classes,resizer_fn=lambda x:x)


@jax.jit
def process_physical_batch(t, params):
    (
        state,
        accumulated_clipped_grads,
        logical_batch_X,
        logical_batch_y,
        masks,
    ) = params
    # slice
    start_idx = t * args.physical_bs

    start_shape = (start_idx,0,) + (0,)*len(ORIG_IMAGE_SHAPE)

    batch_shape = (args.physical_bs,1,) + ORIG_IMAGE_SHAPE

    pb = jax.lax.dynamic_slice(
        logical_batch_X,
        start_shape,
        batch_shape,
    )
    yb = jax.lax.dynamic_slice(logical_batch_y, (start_idx,), (args.physical_bs,))
    mask = jax.lax.dynamic_slice(masks, (start_idx,), (args.physical_bs,))

    # compute grads and clip
    per_example_gradients = compute_per_example_gradients_physical_batch(state, pb, yb, loss_fn)
    clipped_grads_from_pb = clip_physical_batch(per_example_gradients, args.clipping_norm)
    sum_of_clipped_grads_from_pb = accumulate_physical_batch(clipped_grads_from_pb, mask)
    accumulated_clipped_grads = add_trees(accumulated_clipped_grads, sum_of_clipped_grads_from_pb)

    return (
        state,
        accumulated_clipped_grads,
        logical_batch_X,
        logical_batch_y,
        masks,
    )

# 3. Full training loop

The below cell executes the main training loop. It consists of the following parts at every step:

- Poission sampling of the logical batch size (`poisson_sample_logical_batch_size`): This gives us the logical batch size using Poisson subsampling.
- Rounding up of the logical batch size so that there are full physical batches (`setup_physical_batches`): This rounds up the logical batch size so that it is divisible in full physical batches
- Padding of the logical batches (`get_padded_logical_batch`): Here we load the actual images and labels.
- Computation of the per sample gradients (`jax.lax.fori_loop` using the previously defined `process_physical_batch`): This efficiently computes the per-example gradients of the logical batch.
- Addition of noise (`add_Gaussian_noise`): Add the required noise to the accumulated gradients of the logical batch. 
- Update of the model (`update_model`): Apply the gradient update to the model weights.

At the end of a step the following things are executed:
- Computation of the throughput: Compute the number of processed examples divided by the time spent.
- Evaluate the model (`model_evaluation`): Compute the test accuracy.
- Compute the spent privacy budget (`compute_epsilon`): Compute the spent privacy budget using a privacy accountant.

In [9]:
from jaxdpopt.dp_accounting_utils import compute_epsilon
from jaxdpopt.jax_mask_efficient import (
    get_padded_logical_batch,
    model_evaluation,
    add_Gaussian_noise,
    poisson_sample_logical_batch_size,
    setup_physical_batches,
    update_model,
)

times = []
logical_batch_sizes = []

for t in range(args.num_steps):
    sampling_rng = jax.random.key(t + 1)
    batch_rng, binomial_rng, noise_rng = jax.random.split(sampling_rng, 3)

    #######
    # poisson subsample
    actual_batch_size = dataset_size * subsampling_ratio

    # determine padded_logical_bs so that there are full physical batches
    # and create appropriate masks to mask out unnessary elements later
    masks, n_physical_batches = setup_physical_batches(
        actual_logical_batch_size=actual_batch_size,
        physical_bs=args.physical_bs,
    )

    # get random padded logical batches that are slighly larger actual batch size
    padded_logical_batch_X, padded_logical_batch_y = get_padded_logical_batch(
        batch_rng=batch_rng,
        padded_logical_batch_size=len(masks),
        train_X=train_images,
        train_y=train_labels,
    )

    padded_logical_batch_X = padded_logical_batch_X.reshape(-1, 1, 3, ORIG_IMAGE_DIMENSION, ORIG_IMAGE_DIMENSION)

    # cast to GPU
    if USE_GPU:
        padded_logical_batch_X = jax.device_put(padded_logical_batch_X, jax.devices("gpu")[0])
        padded_logical_batch_y = jax.device_put(padded_logical_batch_y, jax.devices("gpu")[0])
        masks = jax.device_put(masks, jax.devices("gpu")[0])

    print("##### Starting gradient accumulation #####", flush=True)
    ### gradient accumulation
    params = state.params

    accumulated_clipped_grads0 = jax.tree.map(lambda x: 0.0 * x, params)

    start = time.time()

    # Main loop
    _, accumulated_clipped_grads, *_ = jax.lax.fori_loop(
        0,
        n_physical_batches,
        process_physical_batch,
        (
            state,
            accumulated_clipped_grads0,
            padded_logical_batch_X,
            padded_logical_batch_y,
            masks,
        ),
    )
    noisy_grad = add_Gaussian_noise(noise_rng, accumulated_clipped_grads, noise_std, args.clipping_norm)

    # update
    state = jax.block_until_ready(update_model(state, noisy_grad))

    end = time.time()
    duration = end - start

    times.append(duration)
    logical_batch_sizes.append(actual_batch_size)

    print(f"throughput at iteration {t}: {actual_batch_size / duration}", flush=True)

    acc_iter = model_evaluation(
        state, test_images, test_labels, batch_size=10, orig_img_shape=ORIG_IMAGE_SHAPE, use_gpu=USE_GPU
    )
    print(f"accuracy at iteration {t}: {acc_iter}", flush=True)

    # Compute privacy guarantees
    epsilon, delta = compute_epsilon(
        noise_multiplier=noise_std,
        sample_rate=subsampling_ratio,
        steps=t + 1,
        target_delta=args.target_delta,
        accountant=args.accountant,
    )
    privacy_results = {"accountant": args.accountant, "epsilon": epsilon, "delta": delta}
    print(privacy_results, flush=True)

##### Starting gradient accumulation #####
throughput at iteration 0: 30.08940151467706
accuracy at iteration 0: 0.5320000052452087
{'accountant': 'pld', 'epsilon': 5.957165897649239, 'delta': 9.999999999661162e-06}
##### Starting gradient accumulation #####
throughput at iteration 1: 71.89429801609089
accuracy at iteration 1: 0.48399999737739563
{'accountant': 'pld', 'epsilon': 6.508521466852355, 'delta': 9.999999999527538e-06}
##### Starting gradient accumulation #####
throughput at iteration 2: 528.8039860407883
accuracy at iteration 2: 0.5929999947547913
{'accountant': 'pld', 'epsilon': 6.8372466859444465, 'delta': 9.999999999457034e-06}
##### Starting gradient accumulation #####
throughput at iteration 3: 480.38226219136476
accuracy at iteration 3: 0.6455000042915344
{'accountant': 'pld', 'epsilon': 7.0804913276558, 'delta': 9.999999999435226e-06}
##### Starting gradient accumulation #####
throughput at iteration 4: 520.2385928528814
accuracy at iteration 4: 0.6754999756813049
{'a

## 3.1 Final Model evaluation
Here we computate of the throughput (num processed examples/time spent) and final test accuracy (`model_evaluation`).

In [10]:
acc_last = model_evaluation(state, test_images, test_labels, batch_size=10, use_gpu=USE_GPU, orig_img_shape=ORIG_IMAGE_SHAPE)

print("times \n", times, flush=True)

print("batch sizes \n ", logical_batch_sizes, flush=True)

print("accuracy at end of training", acc_last, flush=True)
thr = np.mean(np.array(logical_batch_sizes) / np.array(times))
print("throughput", thr)

times 
 [3.3234293460845947, 1.3909308910369873, 0.18910598754882812, 0.20816755294799805, 0.19221949577331543, 0.1946735382080078, 0.19500303268432617, 0.18951869010925293, 0.27461886405944824, 0.38489627838134766]
batch sizes 
  [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0]
accuracy at end of training 0.728
throughput 380.95053185737765
