In [1]:
import numpy as np
import utils
import torch
import pandas as pd
import gpytorch

from tqdm.notebook import tqdm
from gpytorch.models import ExactGP
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel



In [2]:
!nvidia-smi

Thu Feb 16 03:55:09 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.86       Driver Version: 470.86       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN RTX    Off  | 00000000:3B:00.0 Off |                  N/A |
| 41%   32C    P8    13W / 280W |      3MiB / 24220MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    Off  | 00000000:86:00.0 Off |                  Off |
| 38%   61C    P8    17W / 300W |   1225MiB / 48685MiB |      0%      Default |
|       

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda', index=0)

In [5]:
class DirichletGPModel(ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_classes):
        super(DirichletGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean(batch_shape=torch.Size((num_classes,)))
        self.covar_module = ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size((num_classes,))),
            batch_shape=torch.Size((num_classes,)),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def train(train_x, train_y, device_idx=0):
    device = torch.device(f'cuda:{device_idx}' if torch.cuda.is_available() else 'cpu')

    likelihood = DirichletClassificationLikelihood(train_y, learn_additional_noise=True).cuda()
    model = DirichletGPModel(train_x, likelihood.transformed_targets, likelihood, num_classes=likelihood.num_classes).cuda()
    
    model.to(device)
    likelihood.to(device)
    
    training_iterations = 50
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    model.train()
    likelihood.train()
    
    for i in tqdm(range(training_iterations)):
        # Zero backprop gradients
        optimizer.zero_grad()
        # Get output from model
        output = model(train_x)
        # Calc loss and backprop derivatives
        loss = -mll(output, train_y).sum()
        loss.backward()
        optimizer.step()
    
    return model, likelihood

In [6]:
dfx, dfy, _ = utils.get_dataset('adult_income', return_dataframe=True)

In [7]:
from sklearn.preprocessing import minmax_scale
cols = dfx.columns[:4]

In [8]:
cols

Index(['age', 'capital_gain', 'capital_loss', 'hours_per_week'], dtype='object')

In [9]:
dfx_1 = dfx.loc[dfx.gender == 1]
dfx_0 = dfx.loc[dfx.gender == 0]

In [10]:
dfx_1

Unnamed: 0,age,capital_gain,capital_loss,hours_per_week,workclass_Private,workclass_Local-gov,workclass_Self-emp-not-inc,workclass_Federal-gov,workclass_State-gov,workclass_Self-emp-inc,...,native_country_Jamaica,native_country_Ecuador,native_country_Yugoslavia,native_country_Hungary,native_country_Hong,native_country_Greece,native_country_Trinadad&Tobago,native_country_Outlying-US(Guam-USVI-etc),native_country_France,native_country_Holand-Netherlands
8,24,0,0,40,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
12,26,0,0,39,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
17,43,0,0,30,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
18,37,0,0,20,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
21,34,0,0,35,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48827,37,0,0,40,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
48830,43,0,0,40,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
48837,27,0,0,38,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
48839,58,0,0,40,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [11]:
dfx_1[cols] = minmax_scale(dfx_1[cols])
dfx_0[cols] = minmax_scale(dfx_0[cols])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx_1[cols] = minmax_scale(dfx_1[cols])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx_0[cols] = minmax_scale(dfx_0[cols])


In [12]:
df_x1 = dfx_1.sample(1000)
df_x0 = dfx_0.sample(1000)

warm_start_y1 = torch.from_numpy(dfy.loc[df_x1.index].values).to(device)
warm_start_y0 = torch.from_numpy(dfy.loc[df_x0.index].values).to(device)

warm_start_x1 = torch.from_numpy(df_x1.values).float().to(device)
warm_start_x0 = torch.from_numpy(df_x0.values).float().to(device)

In [13]:
model0, likelihood0 = train(warm_start_x0, warm_start_y0)
model1, likelihood1 = train(warm_start_x1, warm_start_y1)

model0.to(device)
likelihood0.to(device)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

DirichletClassificationLikelihood(
  (noise_covar): FixedGaussianNoise()
  (second_noise_covar): HomoskedasticNoise(
    (raw_noise_constraint): GreaterThan(1.000E-04)
  )
)

In [14]:
model0.eval(), model1.eval(), likelihood0.eval(), likelihood1.eval();

In [15]:
from vae_models import VanillaVAE as vae

In [16]:
vae0 = vae()
vae0.load_state_dict(torch.load("checkpoints/adult_income/vae_xA=0/best.pt"))
vae0.eval()
vae0.to(device);

In [17]:
vae1 = vae()
vae1.load_state_dict(torch.load("checkpoints/adult_income/vae_xA=1/best.pt"))
vae1.eval()
vae1.to(device);

In [18]:
from blackbox_models import BlackBox

In [19]:
blackbox = BlackBox('Logistic', 102, 1)
blackbox.load_state_dict(torch.load("/mnt/infonas/data/eeshaan/fairness/EE492/checkpoints/adult_income/blackbox/Logistic/best.pt"))
blackbox.eval()
blackbox.to(device);

In [20]:
lambda_reg = 1.0
candidates = []
neg_queried = warm_start_x0.clone()
neg_labels = warm_start_y0.clone().unsqueeze(1)

for epoch_outer in tqdm(range(1, 4001)):
    x0_random = torch.normal(0.,1.,size=(1,102), dtype=torch.float32, requires_grad=True)
    optimizer0 = torch.optim.AdamW((x0_random,), lr=1.0)
    best_loss = 10e5
    count = 0
    losses = []
    for epoch in range(1,100):
        optimizer0.zero_grad()
        x0_samples = utils.postprocess(
            vae0.sample(x0_random.to(device), 100, device, 
                        **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False}).squeeze(1),
            'adult_income'
        )
        obj0 = likelihood0(model0(x0_samples.to(device))).variance.sum(axis=0).mean() - lambda_reg * ((x0_random.to(device) - x0_samples)**2).mean()
#         print(obj0.shape)
        loss = -obj0
        loss.backward()
        optimizer0.step()
        if loss < best_loss:
            best_loss = loss
            count = 0
            losses.append(loss)
        else:
            count += 1
        if count  == 5:
            break
    x0_query = dfx_0.iloc[np.argmin(np.linalg.norm(np.array(dfx_0) - x0_random.detach().clone().numpy()))]
    candidates.append(torch.Tensor([x0_query]).float().to(device))
    
    if epoch_outer % 100 == 0:
        new_vals = torch.concatenate(candidates)
        new_queries,_, _ = vae0(new_vals.to(device), **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False})
        new_labels = blackbox(new_queries)
        neg_queried = torch.concatenate([neg_queried, new_queries])
        neg_labels = torch.concatenate([neg_labels, (0.5*(torch.sign(new_labels - 0.5) + 1.0)).long().detach().clone()])
        print(neg_labels.shape, neg_labels.dtype)
        print(neg_queried.shape, neg_queried.dtype)
        model0, likelihood0 = train(neg_queried.detach().clone(),neg_labels.flatten())
        model0.eval()
        likelihood0.eval()
        candidates = []

  0%|          | 0/4000 [00:00<?, ?it/s]

torch.Size([1100, 1]) torch.int64
torch.Size([1100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1200, 1]) torch.int64
torch.Size([1200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1300, 1]) torch.int64
torch.Size([1300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1400, 1]) torch.int64
torch.Size([1400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1500, 1]) torch.int64
torch.Size([1500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1600, 1]) torch.int64
torch.Size([1600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1700, 1]) torch.int64
torch.Size([1700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1800, 1]) torch.int64
torch.Size([1800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1900, 1]) torch.int64
torch.Size([1900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2000, 1]) torch.int64
torch.Size([2000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2100, 1]) torch.int64
torch.Size([2100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2200, 1]) torch.int64
torch.Size([2200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2300, 1]) torch.int64
torch.Size([2300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2400, 1]) torch.int64
torch.Size([2400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2500, 1]) torch.int64
torch.Size([2500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2600, 1]) torch.int64
torch.Size([2600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2700, 1]) torch.int64
torch.Size([2700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2800, 1]) torch.int64
torch.Size([2800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2900, 1]) torch.int64
torch.Size([2900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3000, 1]) torch.int64
torch.Size([3000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3100, 1]) torch.int64
torch.Size([3100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3200, 1]) torch.int64
torch.Size([3200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3300, 1]) torch.int64
torch.Size([3300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3400, 1]) torch.int64
torch.Size([3400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3500, 1]) torch.int64
torch.Size([3500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3600, 1]) torch.int64
torch.Size([3600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3700, 1]) torch.int64
torch.Size([3700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3800, 1]) torch.int64
torch.Size([3800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3900, 1]) torch.int64
torch.Size([3900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4000, 1]) torch.int64
torch.Size([4000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4100, 1]) torch.int64
torch.Size([4100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4200, 1]) torch.int64
torch.Size([4200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4300, 1]) torch.int64
torch.Size([4300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4400, 1]) torch.int64
torch.Size([4400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4500, 1]) torch.int64
torch.Size([4500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4600, 1]) torch.int64
torch.Size([4600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4700, 1]) torch.int64
torch.Size([4700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4800, 1]) torch.int64
torch.Size([4800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4900, 1]) torch.int64
torch.Size([4900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([5000, 1]) torch.int64
torch.Size([5000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

In [21]:
candidates = []
pos_queried = warm_start_x1.clone()
pos_labels = warm_start_y1.clone().unsqueeze(1)
print(pos_labels.shape, pos_labels.dtype)
print(pos_queried.shape, pos_queried.dtype)

for epoch_outer in tqdm(range(1, 4001)):
    x1_random = torch.normal(0.,1.,size=(1,102), dtype=torch.float32, requires_grad=True)
    optimizer1 = torch.optim.AdamW((x1_random,), lr=1.0)
    best_loss = 10e5
    count = 0
    losses = []
    for epoch in range(1,100):
        optimizer1.zero_grad()
        x1_samples = utils.postprocess(
            vae1.sample(x1_random.to(device), 100, device, 
                        **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False}).squeeze(1),
            'adult_income'
        )
        obj1 = likelihood1(model1(x1_samples.to(device))).variance.sum(axis=0).mean() - lambda_reg * ((x1_random.to(device) - x1_samples)**2).mean()
#         print(obj0.shape)
        loss = -obj1
        loss.backward()
        optimizer1.step()
        if loss < best_loss:
            best_loss = loss
            count = 0
            losses.append(loss)
        else:
            count += 1
        if count  == 5:
            break
    x1_query = dfx_1.iloc[np.argmin(np.linalg.norm(np.array(dfx_1) - x1_random.detach().clone().numpy()))]
    candidates.append(torch.Tensor([x1_query]).float().to(device))
    
    if epoch_outer % 100 == 0:
        new_vals = torch.concatenate(candidates)
        new_queries,_ , _ = vae1(new_vals.to(device), **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False})
        new_labels = blackbox(new_queries)
        pos_queried = torch.concatenate([pos_queried, new_queries])
        pos_labels = torch.concatenate([pos_labels, (0.5*(torch.sign(new_labels - 0.5) + 1.0)).long().detach().clone()])
        model1, likelihood1 = train(pos_queried.detach().clone(),pos_labels.flatten())
        model1.eval()
        likelihood1.eval()
        candidates = []

torch.Size([1000, 1]) torch.int64
torch.Size([1000, 102]) torch.float32


  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [22]:
pos_labels.shape

torch.Size([5000, 1])

In [23]:
new_labels.shape

torch.Size([100, 1])

In [24]:
# parity in data
np.abs(dfy[dfx[dfx.gender == 0].index].mean() - dfy[dfx[dfx.gender == 1].index].mean())

0.19911019753072282

In [25]:
bb_input0 = torch.from_numpy(dfx_0.values).float().to(device)
bb_input1 = torch.from_numpy(dfx_1.values).float().to(device)

In [26]:
y0 = blackbox(bb_input0)
y1 = blackbox(bb_input1)

In [27]:
y0_ = torch.round(y0)
y1_ = torch.round(y1)

In [28]:
y0_.sum()

tensor(7659., device='cuda:0', grad_fn=<SumBackward0>)

In [29]:
torch.abs(y0_.mean() - y1_.mean())

tensor(0.1869, device='cuda:0', grad_fn=<AbsBackward0>)

In [35]:
dp_value = torch.abs(pos_labels.squeeze().float().mean() - neg_labels.squeeze().float().mean())
dp_value

tensor(0.0406, device='cuda:0')

In [36]:
pos_labels.float().mean()

tensor(0.0196, device='cuda:0')

In [37]:
neg_labels.float().mean()

tensor(0.0602, device='cuda:0')

In [39]:
torch.save(dp_value, 'results/adult_income/dp_value_vanilla_vae_reg.pt')