In [None]:
import os, sys
os.chdir("../..")
sys.path.append(os.getcwd())

import breaching
import numpy as np
from pprint import pprint
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
import copy
from tqdm import tqdm

import scipy.stats as stats
import scipy.integrate as integrate
import numpy as np
from tqdm import tqdm

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

print(torch.__version__, torchvision.__version__)

In [None]:
# Directory, below the current working directory, that receives the saved
# sensitivity tensors. The trailing slash matters: file names are appended to
# this string directly.
sensitivity_path = f"{os.getcwd()}/classification/sensitivity/"
print(sensitivity_path)

In [None]:
# Use the GPU when PyTorch reports one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)

In [None]:
def flat_tensor_list(grads):
    Shapes_orgianl = []
    flat_tensor_lists = torch.tensor([], device=grads[0].device)
    for i in grads:
        Shapes_orgianl.append(i.shape)
        flat_tensor_lists = torch.concat((flat_tensor_lists, i.view(-1)), dim = 0)
    return flat_tensor_lists

In [None]:
def min_max_normalize(l: list) -> list:
    """Rescale the values of ``l`` linearly so the minimum maps to 0 and the maximum to 1.

    Args:
        l: List of numbers.

    Returns:
        A new list of the same length. An empty input gives an empty list.
        If all values are equal there is no range to normalize over, so every
        entry is returned as 0.0 instead of dividing by zero.
    """
    if len(l) == 0:
        return []

    l_min = min(l)
    span = max(l) - l_min
    if span == 0:
        return [0.0 for _ in l]

    return [(x - l_min) / span for x in l]

def scale_to_100(l: list) -> list:
    """Turn a list of numbers into percentage shares that sum to 100.

    The values are first min-max normalized (so the smallest value receives a
    share of 0) and then divided by their total.

    Args:
        l: List of numbers.

    Returns:
        A new list of percentages. An empty input gives an empty list. When
        all values are equal (including a single-element list) the
        normalization has no range to work with, so every entry gets an equal
        share ``100 / len(l)``.
    """
    if len(l) == 0:
        return []

    # Constant input: handled here so neither the normalization nor the
    # division by the (zero) sum below can fail.
    if max(l) == min(l):
        return [100.0 / len(l)] * len(l)

    normalized = min_max_normalize(l)
    total = sum(normalized)  # > 0 here: the maximum maps to 1, the rest are >= 0
    return [x / total * 100 for x in normalized]

def plot_layer_sens_mean(sens: list, model_name: str):
    """Draw a bar chart of each layer's share of the mean sensitivity.

    Args:
        sens: One tensor per layer; the mean of each tensor is used as that
            layer's value.
        model_name: Text used as the figure title.

    The per-layer means are converted to percentages with ``scale_to_100``
    and plotted against a 1-based layer index.
    """
    layer_means = []
    for layer_sens in sens:
        layer_means.append(layer_sens.mean().item())
    layer_shares = scale_to_100(layer_means)

    # Layers are numbered from 1 on the x-axis.
    layer_indices = np.arange(1, len(layer_shares) + 1)

    plt.figure()
    plt.bar(layer_indices, layer_shares)
    plt.title(model_name)
    plt.xlabel("Layer Index")
    plt.ylabel("Mean Sensitivity Ratio (%)")
    plt.show()


In [None]:
def get_data_point(user, setup):
    """Return the first sample of the first batch from ``user.dataloader``.

    Args:
        user: Object exposing ``dataloader`` (iterable of dict-like batches of
            tensors) and ``generator_input`` (``None`` or an object with a
            ``sample(shape)`` method).
        setup: Mapping with a ``"device"`` entry; every tensor in the batch is
            moved there.

    Returns:
        A dict with the same keys as the batch, each value sliced to its first
        element (the leading dimension is kept, with size 1). If
        ``user.generator_input`` is set, a sample drawn from it is added to
        the ``"input_ids"`` entry when present, otherwise to ``"inputs"``.
        An empty dataloader yields an empty dict.
    """
    for batch in user.dataloader:
        on_device = {name: batch[name].to(device=setup["device"]) for name in batch}
        input_key = "input_ids" if "input_ids" in on_device.keys() else "inputs"

        # Slicing with 0:1 keeps the batch dimension.
        first_sample = {name: tensor[0 : 1] for name, tensor in on_device.items()}

        perturb = user.generator_input is not None
        sample_input = first_sample[input_key]
        if perturb:
            sample_input = sample_input + user.generator_input.sample(sample_input.shape)
        first_sample[input_key] = sample_input

        # Only the first batch is ever consumed.
        return first_sample

    return dict()

In [None]:
def get_sensitivity(model, loss_fn, data_point, device, discrete_grad=False, grad_on_x=False, use_jacobian=False):
    """Compute per-parameter sensitivity values for a single sample.

    The model is switched to eval mode, called as ``model(**data_point)`` and
    the loss is evaluated against a one-hot target built from
    ``data_point["labels"][0]``. Progress information is printed.

    Args:
        model: ``torch.nn.Module`` accepting the entries of ``data_point`` as
            keyword arguments and returning a tensor indexed as
            ``outputs[0, class]``.
        loss_fn: Callable ``loss_fn(outputs, one_hot_target)`` returning a
            scalar tensor.
        data_point: Dict with at least ``"labels"``; ``"inputs"`` is required
            when ``grad_on_x`` is true (its ``requires_grad`` flag is set, so
            the passed tensor is modified).
        device: Device on which the label / one-hot tensors are created.
        discrete_grad: Only used when ``grad_on_x`` is false. If true, return
            for every parameter the element-wise maximum of the absolute
            differences between the gradient for the true label and the
            gradients for the two neighbouring labels.
        grad_on_x: If true, differentiate the parameter gradients with respect
            to the input instead of varying the label.
        use_jacobian: Only used when ``grad_on_x`` is true. If true, delegate
            to ``torch.autograd.functional.jacobian``.

    Returns:
        * ``grad_on_x`` and ``use_jacobian``: whatever
          ``torch.autograd.functional.jacobian`` returns for the parameter
          gradients with respect to ``data_point["inputs"]``.
        * ``grad_on_x`` only: a list with one 1-D tensor per parameter; entry
          ``j`` is the mean over all input elements of the derivative of the
          ``j``-th gradient component with respect to the input.
        * otherwise: a list with one tensor per parameter, either the detached
          gradient for the true label (``discrete_grad`` false) or the
          discrete label sensitivity described above.
    """
    model.eval()
    labels = data_point["labels"].to(torch.float32)
    gt_label = torch.Tensor([labels[0]]).long().to(device)
    print("gt_label:", gt_label.item())

    if grad_on_x:
        data_point["inputs"].requires_grad = True
        outputs = model(**data_point)

        gt_one_hot_label = torch.zeros_like(outputs).to(device)
        gt_one_hot_label[0, gt_label.item()] = 1
        print("One hot label:", torch.argmax(gt_one_hot_label, dim=-1).item())
        if use_jacobian:
            # NOTE(review): grad_w reads param.grad after backward() and never
            # clears it between calls - verify that this accumulation is what
            # the jacobian computation is meant to see.
            def grad_w(data_in):
                data_point_input = {
                    "inputs": data_in,
                    "labels": data_point["labels"]
                }
                out = model(**data_point_input)
                l = loss_fn(out, gt_one_hot_label)
                l.backward(create_graph=True)
                dl_dw = [param.grad for param in model.parameters()]
                return tuple(dl_dw)

            d2l_dwdx = torch.autograd.functional.jacobian(
                grad_w,
                data_point["inputs"]
            )
            return d2l_dwdx
        else:
            loss = loss_fn(outputs, gt_one_hot_label)
            # create_graph=True keeps the gradients differentiable so each
            # gradient component can be back-propagated to the input below.
            loss.backward(create_graph=True)

            dl_dw = [param.grad for param in model.parameters()]

            d2l_dwdx = []
            for idx, layer_grad in enumerate(dl_dw):
                print(f"Layer {idx+1}")
                sens_layer = torch.zeros_like(layer_grad.view(-1))

                layer_grad_flatten = layer_grad.view(-1)
                for j in tqdm(range(len(layer_grad_flatten))):
                    # Reset the input gradient so each component is measured
                    # on its own rather than accumulated.
                    data_point["inputs"].grad.data.zero_()
                    layer_grad_flatten[j].backward(retain_graph=True)
                    sens_layer[j] = data_point["inputs"].grad.mean().clone().detach()
                d2l_dwdx.append(sens_layer)
            return d2l_dwdx
    else:
        outputs = model(**data_point)

        # Neighbouring class indices wrap around the class dimension. The
        # lower neighbour of class 0 already wrapped to the last class through
        # negative indexing; the upper neighbour of the last class used to be
        # out of bounds and raised an IndexError.
        num_classes = outputs.shape[1]
        label_idx = gt_label.item()
        label_minus_idx = (label_idx - 1) % num_classes
        label_plus_idx = (label_idx + 1) % num_classes

        gt_one_hot_label = torch.zeros_like(outputs).to(device)
        gt_one_hot_label_minus = torch.zeros_like(outputs).to(device)
        gt_one_hot_label_plus = torch.zeros_like(outputs).to(device)

        gt_one_hot_label[0, label_idx] = 1
        gt_one_hot_label_minus[0, label_minus_idx] = 1
        gt_one_hot_label_plus[0, label_plus_idx] = 1
        print("One hot label:", torch.argmax(gt_one_hot_label, dim=-1).item())

        def get_grad(labels):
            """Return detached parameter gradients of the loss for the given target."""
            l = loss_fn(outputs, labels)
            # create_graph=True also retains the graph, which is needed because
            # the same forward pass is back-propagated three times.
            l.backward(create_graph=True)
            print("Loss:", l)
            grad_list = [param.grad.clone().detach() for param in model.parameters()]
            # Clear the gradients so the next call does not accumulate on top.
            model.zero_grad()
            return grad_list

        dl_dw = get_grad(gt_one_hot_label)
        # Total number of gradient entries, counted without building a
        # flattened copy of all gradients.
        print("len(dl_dw):", sum(layer_grad.numel() for layer_grad in dl_dw))
        dl_dw_minus = get_grad(gt_one_hot_label_minus)
        dl_dw_plus = get_grad(gt_one_hot_label_plus)

        assert len(dl_dw) == len(dl_dw_minus) == len(dl_dw_plus)

        if discrete_grad:
            d2l_dwdy = [
                torch.max(torch.abs(grad_minus - grad_center), torch.abs(grad_center - grad_plus))
                for grad_minus, grad_center, grad_plus in zip(dl_dw_minus, dl_dw, dl_dw_plus)
            ]
        else:
            d2l_dwdy = dl_dw

        return d2l_dwdy

In [None]:
def get_mean_sens(cfg_config, model_name, device, num_user=5, discrete_grad=False, grad_on_x=False, use_jacobian=False):
    """Average the sensitivity of one sample per user over ``num_user`` users and save it.

    For users ``1 .. num_user`` a case is constructed with
    ``breaching.cases.construct_case``, one data point is taken with
    ``get_data_point`` and passed to ``get_sensitivity``. The per-layer
    results are averaged with equal weight ``1 / num_user`` for every user.

    Args:
        cfg_config: Configuration object; ``cfg_config.case.user.user_idx`` is
            overwritten for every user.
        model_name: Prefix of the output file name.
        device: Device used for the model and the data.
        num_user: Number of users to average over (must be >= 1).
        discrete_grad, grad_on_x, use_jacobian: Forwarded to
            ``get_sensitivity``. ``discrete_grad`` also adds a ``_discrete``
            suffix to the file name.

    Returns:
        Tuple ``(sens_mean, sens_mean_path)``: the list of averaged per-layer
        tensors and the path of the file they were saved to with
        ``torch.save``.

    Raises:
        ValueError: If ``num_user`` is smaller than 1.
    """
    if num_user < 1:
        raise ValueError("num_user must be at least 1")

    # The output path does not depend on the loop, so build it once up front.
    sens_mean_path = sensitivity_path + model_name + "_mean_sens"
    if discrete_grad:
        sens_mean_path += "_discrete"
    sens_mean_path += ".pt"

    sens_mean = None
    for i in tqdm(range(num_user)):
        cfg_config.case.user.user_idx = i+1 # From which user?
        setup = dict(device=device, dtype=getattr(torch, cfg_config.case.impl.dtype))
        user, _server, model, loss_fn = breaching.cases.construct_case(cfg_config.case, setup)
        model.to(device=setup["device"])

        data_point = get_data_point(user, setup)

        # A deep copy is passed because get_sensitivity may modify the tensors
        # it receives (it sets requires_grad on the inputs).
        sens_single = get_sensitivity(model, loss_fn,
                                      copy.deepcopy(data_point),
                                      setup["device"],
                                      discrete_grad,
                                      grad_on_x,
                                      use_jacobian)

        # Every user contributes with the same weight 1 / num_user. Previously
        # the first user's result was added unscaled, which biased the mean.
        weighted = [layer_sens / num_user for layer_sens in sens_single]
        if sens_mean is None:
            sens_mean = weighted
        else:
            sens_mean = [acc + contrib for acc, contrib in zip(sens_mean, weighted)]

    # torch.save does not create missing directories.
    out_dir = os.path.dirname(sens_mean_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(sens_mean, sens_mean_path)
    print(sens_mean_path)
    return sens_mean, sens_mean_path


## LeNet (CIFAR)

In [None]:
cfg = breaching.get_config(overrides=["case=6_large_batch_cifar"])
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# Model and data partition for this experiment.
cfg.case.model = "lenet100"
cfg.case.data.partition = "balanced"  # original note: 100 unique CIFAR-100 images
# Placeholder only: get_mean_sens overwrites user_idx for every user it visits.
cfg.case.user.user_idx = 0

# Labels are not handed to the attacker; the "yin" strategy is selected for
# label recovery (original note: also works here, as labels are unique).
cfg.case.user.provide_labels = False
cfg.attack.label_strategy = "yin"

# Original note: total variation regularization needs to be smaller on CIFAR-10.
# NOTE(review): the partition comment mentions CIFAR-100 - confirm which dataset this case uses.
cfg.attack.regularization.total_variation.scale = 5e-4

In [None]:
# Average the sensitivity over five users on the CPU, then plot the per-layer shares.
lenet_sens_mean, lenet_mean_path = get_mean_sens(
    cfg,
    "lenet",
    torch.device("cpu"),
    num_user=5,
    discrete_grad=False,
)
plot_layer_sens_mean(lenet_sens_mean, "LeNet")

## CNN (CIFAR)

In [None]:
cfg = breaching.get_config(overrides=["case=6_large_batch_cifar"])
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# Model and data partition for this experiment.
cfg.case.model = "CNN_FedAvg"
cfg.case.data.partition = "balanced"  # original note: 100 unique CIFAR-100 images
# Placeholder only: get_mean_sens overwrites user_idx for every user it visits.
cfg.case.user.user_idx = 0

# Labels are not handed to the attacker; the "yin" strategy is selected for
# label recovery (original note: also works here, as labels are unique).
cfg.case.user.provide_labels = False
cfg.attack.label_strategy = "yin"

# Original note: total variation regularization needs to be smaller on CIFAR-10.
# NOTE(review): the partition comment mentions CIFAR-100 - confirm which dataset this case uses.
cfg.attack.regularization.total_variation.scale = 5e-4

In [None]:
# Average the sensitivity over five users on the CPU, then plot the per-layer shares.
CNN_FedAvg_sens_mean, CNN_FedAvg_mean_path = get_mean_sens(
    cfg,
    "CNN_FedAvg",
    torch.device("cpu"),
    num_user=5,
    discrete_grad=False,
)
plot_layer_sens_mean(CNN_FedAvg_sens_mean, "CNN_FedAvg")

## ResNet18 (CIFAR)

In [None]:
cfg = breaching.get_config(overrides=["case=6_large_batch_cifar"])
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# Model and data partition for this experiment.
cfg.case.model = "resnet18"
cfg.case.data.partition = "unique-class"

# Labels are not handed to the attacker; the "yin" strategy is selected for
# label recovery (original note: also works here, as labels are unique).
cfg.case.user.provide_labels = False
cfg.attack.label_strategy = "yin"

# Original note: total variation regularization needs to be smaller on CIFAR-10.
cfg.attack.regularization.total_variation.scale = 5e-4

In [None]:
# Average the sensitivity over five users on the CPU, then plot the per-layer shares.
resnet18_cifar_sens_mean, resnet18_cifar_mean_path = get_mean_sens(
    cfg,
    "resnet18_cifar",
    torch.device("cpu"),
    num_user=5,
    discrete_grad=False,
)
plot_layer_sens_mean(resnet18_cifar_sens_mean, "ResNet18")