Same uncertainty value for every input #567

Closed
zedoul opened this issue Feb 21, 2023 · 3 comments
@zedoul

zedoul commented Feb 21, 2023

🐛 Bug

Hi,

First of all, thank you for a great open-source product! I think I have found a potential issue in Opacus, so I would like to report it.

A short summary: an Opacus-based classification model generates the same uncertainty value for every input when the number of features is 4, on the Wine dataset.

My hypothesis at the moment is that this error can be reproduced when 1) the model is a classification model, 2) the data is tabular, and 3) the number of features is 4. The last condition may sound strange, and I have tried to reproduce the bug under other conditions, but that is where my conclusion stands right now. I am fairly sure the number of features is not the root cause itself, and there must be other datasets that trigger the same error without the third condition. However, I hope my sample code is at least reproducible on your machine. Also, the same data works fine with the equivalent TensorFlow Privacy implementation.

Here is my reproduction. I tried to use Colab as recommended, but unfortunately I do not know how to use it. Instead, I copied and pasted the code and attached the dataset. The dataset I used is the Kaggle Wine data, but as I said, I have seen the error on other data too.

My sample code is not great; for example, some hyperparameters could be further tuned. Regardless, I expect the Opacus-enhanced version of the PyTorch model to produce different uncertainty values for different inputs, especially since the plain PyTorch model, without Opacus, does.

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as torch_optim
import torch.nn.functional as F

import opacus

batch_size = 50
epochs = 10
lr = 0.0001
n_class = 10

df = pd.read_csv("wine.csv")

# Use 4 numeric feature columns and the last column as the label
X = df.iloc[:, 5:9]
y = df.iloc[:, -1]
X = X.to_numpy()
y = y.to_numpy()

assert X.shape[1] == 4

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

train_priv_X, test_pub_X, train_priv_Y, test_pub_Y = train_test_split(X, y, test_size=0.2, stratify=y)

train_priv_Y = LabelEncoder().fit_transform(train_priv_Y)
test_pub_Y = LabelEncoder().fit_transform(test_pub_Y)

from torch.utils.data import Dataset, DataLoader

class PowerDataset(Dataset):
    def __init__(self, X, Y):
        X = X.copy()
        self.X = X.astype(np.float32) #numerical columns
        self.y = Y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


train_ds = PowerDataset(train_priv_X, train_priv_Y)
test_ds = PowerDataset(test_pub_X, test_pub_Y)
train_dl = DataLoader(train_ds, batch_size=batch_size,shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size,shuffle=True)

class PowerModel(nn.Module):
    def __init__(self, n_cont, n_class):
        super().__init__()
        self.n_cont = n_cont
        self.n_class = n_class
        self.lin1 = nn.Linear(self.n_cont, 20)
        self.lin2 = nn.Linear(20, 20)
        self.lin3 = nn.Linear(20, self.n_class)
        self.bn1 = nn.BatchNorm1d(self.n_cont)
        self.bn2 = nn.BatchNorm1d(20)
        self.bn3 = nn.BatchNorm1d(20)

    def forward(self, x_cont):
        x = self.bn1(x_cont)
        x = F.relu(self.lin1(x))
        x = self.bn2(x)
        x = F.relu(self.lin2(x))
        x = self.bn3(x)
        x = self.lin3(x)
        return x

n_col = train_priv_X.shape[1]
model = PowerModel(n_col, n_class)
optim = torch_optim.Adam(model.parameters(), lr=lr)

def train_model(model, optim, train_dl):
    model.train()
    total = 0
    sum_loss = 0
    for x, y in train_dl:
        batch = y.shape[0]
        output = model(x)
        loss = F.cross_entropy(output, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        total += batch
        sum_loss += batch*(loss.item())
    return sum_loss/total

def val_loss(model, valid_dl):
    model.eval()
    total = 0
    sum_loss = 0
    correct = 0
    for x, y in valid_dl:
        current_batch_size = y.shape[0]
        out = model(x)
        loss = F.cross_entropy(out, y)
        sum_loss += current_batch_size*(loss.item())
        total += current_batch_size
        pred = torch.max(out, 1)[1]
        correct += (pred == y).float().sum().item()
    return sum_loss/total, correct/total

def train_loop(model, epochs, optim, train_dl, test_dl):
    for i in range(epochs):
        loss = train_model(model, optim, train_dl)
        print(i, "training loss: ", loss)
        vloss, accr = val_loss(model, train_dl)
        print("train: valid loss %.3f and accuracy %.3f" % (vloss, accr))
        vloss, accr = val_loss(model, test_dl)
        print("test: valid loss %.3f and accuracy %.3f" % (vloss, accr))

from sklearn import preprocessing

def calc_uncertainty(model, target_ds):
    # Per-sample softmax probabilities are used as the "uncertainty" values
    target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=True)
    preds = []
    with torch.no_grad():
        for x, y in target_dl:
            out = model(x)
            prob = F.softmax(out, dim=1)
            preds.append(prob)

    # Flatten the per-batch tensors into a single list of probability vectors
    final_probs = [item for sublist in preds for item in sublist]
    return final_probs

from opacus.validators import ModuleValidator

# Re-create the model and let Opacus replace DP-incompatible layers
# (here, the BatchNorm layers) with supported alternatives
model = PowerModel(n_col, n_class)
m = ModuleValidator.fix(model)
optim = torch_optim.Adam(m.parameters(), lr=lr)

from opacus import PrivacyEngine

epsilon = 0.5
delta = 0.001
max_grad_norm = 1.0

# Calibrate the DP-SGD noise to the target (epsilon, delta) budget over the given number of epochs
privacy_engine = PrivacyEngine(secure_mode=True)
m, optim, train_dl = privacy_engine.make_private_with_epsilon(
        module=m,
        optimizer=optim,
        data_loader=train_dl,
        target_epsilon=epsilon,
        target_delta=delta,
        epochs=epochs,
        max_grad_norm=max_grad_norm)

train_loop(m, epochs, optim, train_dl, test_dl)

test_uncertain = calc_uncertainty(m, test_ds)
print(test_uncertain)

wine.csv

Please reproduce using our template Colab and post here the link

To Reproduce

  1. Copy the code and the attached dataset into the same directory
  2. Set up the environment
  3. Run python ./opacus_test.py

Expected behavior

A classification model built with Opacus-enhanced PyTorch should generate different uncertainty values for different inputs.

Environment

Here is a requirements.txt file for the environment.

pandas
numpy
scikit-learn
torch==1.8.1
torchcsprng==0.2.1
torchvision
argparse
jupyter
scipy
opacus
  • PyTorch Version (e.g., 1.0): 1.8.1
  • OS (e.g., Linux): Ubuntu 5.15.0-56-generic
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8.5
  • CUDA/cuDNN version: NVIDIA-SMI 470.161.03 Driver Version: 470.161.03 CUDA Version: 11.4
@alexandresablayrolles
Contributor

Thanks for raising this issue. I am not sure this really depends on Opacus: once the model is trained, it is no longer an "Opacus model" but just a regular PyTorch model. I would recommend trying Opacus with various noise levels and/or clipping values to see whether that changes the uncertainty predictions.
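
For example, a minimal sketch of such a sweep (not from the original thread; it reuses PowerModel, n_col, n_class, lr, epochs, train_dl, test_dl, train_loop, calc_uncertainty and test_ds from the reporter's script, and the noise/clipping values are arbitrary):

from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
import torch.optim as torch_optim

for noise_multiplier in (0.5, 1.0, 2.0):
    for max_grad_norm in (0.5, 1.0, 5.0):
        model = ModuleValidator.fix(PowerModel(n_col, n_class))
        optimizer = torch_optim.Adam(model.parameters(), lr=lr)
        engine = PrivacyEngine()
        # make_private takes the noise level directly instead of a target epsilon
        model, optimizer, dl = engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
        )
        train_loop(model, epochs, optimizer, dl, test_dl)
        # Inspect a few probability vectors for this noise/clipping setting
        print(noise_multiplier, max_grad_norm, calc_uncertainty(model, test_ds)[:3])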

alexandresablayrolles self-assigned this Feb 23, 2023
@ffuuugor
Contributor

Hi @zedoul
As Alex mentioned, this issue isn't related to the DP-SGD training process, but rather to the model itself. That said, Opacus does play a part here, since it switches BatchNorm to GroupNorm.

So what happens is:

  1. ModuleValidator replaces the BatchNorm1d(4) layer with GroupNorm(4, 4).
  2. When applied to batched 1-dimensional input, GroupNorm(4, 4) effectively erases the data, as it applies the normalization over an array of size 1 for every feature (see the sketch after this list).
  3. The model is then not really training, since every input is reduced to an array of zeros after the first normalization layer.
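
A minimal sketch of that effect (not from the original thread; the batch size is arbitrary):

import torch
import torch.nn as nn

x = torch.randn(8, 4)                               # batch of 8 samples, 4 features

gn_4 = nn.GroupNorm(num_groups=4, num_channels=4)   # what ModuleValidator.fix() produces for BatchNorm1d(4)
gn_2 = nn.GroupNorm(num_groups=2, num_channels=4)   # a smaller group count, for comparison

print(gn_4(x))   # every entry is ~0: each group holds a single value, so x - mean == 0
print(gn_2(x))   # groups of size 2 preserve per-sample variation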

You can fix this on your side by manually replacing the BatchNorms with GroupNorms that use a more sensible number of groups, or with another type of normalization, e.g. as sketched below.
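
For instance, a possible manual replacement could look like the following (a sketch only; the layer names and group counts are illustrative, not the reporter's code):

import torch.nn as nn
import torch.nn.functional as F

class PowerModelGN(nn.Module):
    # Variant of PowerModel with hand-picked GroupNorm layers instead of BatchNorm
    def __init__(self, n_cont, n_class):
        super().__init__()
        self.lin1 = nn.Linear(n_cont, 20)
        self.lin2 = nn.Linear(20, 20)
        self.lin3 = nn.Linear(20, n_class)
        self.gn1 = nn.GroupNorm(1, n_cont)   # one group over all input features
        self.gn2 = nn.GroupNorm(4, 20)       # groups of 5 channels
        self.gn3 = nn.GroupNorm(4, 20)

    def forward(self, x_cont):
        x = self.gn1(x_cont)
        x = F.relu(self.lin1(x))
        x = self.gn2(x)
        x = F.relu(self.lin2(x))
        x = self.gn3(x)
        return self.lin3(x)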

There's one thing we can improve in Opacus: let the user specify the number of groups when applying the .fix() method; right now it defaults to gcd(32, module.num_features).

facebook-github-bot pushed a commit that referenced this issue Apr 3, 2023
Summary:
## Problem
As highlighted by #567, end users have little control over how exactly `ModuleValidator.fix()` deals with BatchNorms.
For example, our approach to choosing the number of groups is `gcd(32, module.num_features)`, which is fine for most cases, but can occasionally break a model (see #567 for a demonstration)

## Solution
Pass `num_groups` as kwarg, allowing clients to control the behaviour

Pull Request resolved: #580

Reviewed By: Anonymani

Differential Revision: D44545935

Pulled By: ffuuugor

fbshipit-source-id: fe006ecba9c4714b523ae8455fd68b7b5c96f75a
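
Based on that summary, usage after this change would presumably look like the following (a sketch; the exact way the keyword is forwarded is assumed from the commit description rather than taken from the thread):

from opacus.validators import ModuleValidator

model = PowerModel(n_col, n_class)
# Override the default gcd(32, num_features) group count used when
# BatchNorm layers are swapped for GroupNorm (the value 1 is illustrative)
model = ModuleValidator.fix(model, num_groups=1)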
@zedoul
Author

zedoul commented Apr 12, 2023

@alexandresablayrolles @ffuuugor I appreciate the assistance from both of you. After implementing the updated model code, it works correctly. Finally, I can compare Opacus and TF Privacy in a better way. Once again, thank you.
