In [1]:
import numpy as np
import utils
import torch
import pandas as pd
import gpytorch

from tqdm import tqdm
from gpytorch.models import ExactGP
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!nvidia-smi

Wed Feb 15 18:01:56 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:21:00.0 Off |                  Off |
| 30%   44C    P8    26W / 300W |      3MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:22:00.0 Off |                  Off |
| 56%   80C    P2   257W / 300W |  22058MiB / 48685MiB |     80%      Default |
|       

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda', index=0)

In [5]:
class DirichletGPModel(ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_classes):
        super(DirichletGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean(batch_shape=torch.Size((num_classes,)))
        self.covar_module = ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size((num_classes,))),
            batch_shape=torch.Size((num_classes,)),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def train(train_x, train_y, device_idx=0):
    device = torch.device(f'cuda:{device_idx}' if torch.cuda.is_available() else 'cpu')

    likelihood = DirichletClassificationLikelihood(train_y, learn_additional_noise=True).cuda()
    model = DirichletGPModel(train_x, likelihood.transformed_targets, likelihood, num_classes=likelihood.num_classes).cuda()
    
    model.to(device)
    likelihood.to(device)
    
    training_iterations = 50
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    model.train()
    likelihood.train()
    
    for i in tqdm(range(training_iterations)):
        # Zero backprop gradients
        optimizer.zero_grad()
        # Get output from model
        output = model(train_x.to(device))
        # Calc loss and backprop derivatives
        loss = -mll(output, train_y.to(device)).sum()
        loss.backward()
        optimizer.step()
    
    return model, likelihood

In [6]:
dfx, dfy, _ = utils.get_dataset('adult_income', return_dataframe=True)

In [7]:
from sklearn.preprocessing import minmax_scale
cols = dfx.columns[:4]

In [8]:
cols

Index(['age', 'capital_gain', 'capital_loss', 'hours_per_week'], dtype='object')

In [9]:
dfx_1 = dfx.loc[dfx.gender == 1]
dfx_0 = dfx.loc[dfx.gender == 0]

In [10]:
dfx_1

Unnamed: 0,age,capital_gain,capital_loss,hours_per_week,workclass_Private,workclass_Local-gov,workclass_Self-emp-not-inc,workclass_Federal-gov,workclass_State-gov,workclass_Self-emp-inc,...,native_country_Jamaica,native_country_Ecuador,native_country_Yugoslavia,native_country_Hungary,native_country_Hong,native_country_Greece,native_country_Trinadad&Tobago,native_country_Outlying-US(Guam-USVI-etc),native_country_France,native_country_Holand-Netherlands
8,24,0,0,40,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
12,26,0,0,39,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
17,43,0,0,30,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
18,37,0,0,20,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
21,34,0,0,35,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48827,37,0,0,40,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
48830,43,0,0,40,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
48837,27,0,0,38,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
48839,58,0,0,40,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [11]:
dfx_1[cols] = minmax_scale(dfx_1[cols])
dfx_0[cols] = minmax_scale(dfx_0[cols])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx_1[cols] = minmax_scale(dfx_1[cols])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx_0[cols] = minmax_scale(dfx_0[cols])


In [12]:
df_x1 = dfx_1.sample(1000)
df_x0 = dfx_0.sample(1000)

warm_start_y1 = torch.from_numpy(dfy.loc[df_x1.index].values).to(device)
warm_start_y0 = torch.from_numpy(dfy.loc[df_x0.index].values).to(device)

warm_start_x1 = torch.from_numpy(df_x1.values).float().to(device)
warm_start_x0 = torch.from_numpy(df_x0.values).float().to(device)

In [13]:
model0, likelihood0 = train(warm_start_x0, warm_start_y0)
model1, likelihood1 = train(warm_start_x1, warm_start_y1)

model0.to(device)
likelihood0.to(device)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 26.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 64.12it/s]


DirichletClassificationLikelihood(
  (noise_covar): FixedGaussianNoise()
  (second_noise_covar): HomoskedasticNoise(
    (raw_noise_constraint): GreaterThan(1.000E-04)
  )
)

In [14]:
model0.eval(), model1.eval(), likelihood0.eval(), likelihood1.eval();

In [15]:
from vae_models import RelaxedBernoulliVAE as rbvae

In [16]:
vae0 = rbvae()
vae0.load_state_dict(torch.load("/mnt/infonas/data/eeshaan/fairness/EE492/checkpoints/adult_income/rbvae_xA=0/best.pt"))
vae0.eval()
vae0.to(device);

In [17]:
vae1 = rbvae()
vae1.load_state_dict(torch.load("/mnt/infonas/data/eeshaan/fairness/EE492/checkpoints/adult_income/rbvae_xA=1/best.pt"))
vae1.eval()
vae1.to(device);

In [18]:
from blackbox_models import BlackBox

In [19]:
blackbox = BlackBox('Logistic', 102, 1)
blackbox.load_state_dict(torch.load("/mnt/infonas/data/eeshaan/fairness/EE492/checkpoints/adult_income/blackbox/Logistic/best.pt"))
blackbox.eval()
blackbox.to(device);

In [None]:
candidates = []
neg_queried = warm_start_x0.clone()
neg_labels = warm_start_y0.clone().unsqueeze(1)

for epoch_outer in tqdm(range(1, 4001)):
    x0_random = torch.normal(0.,1.,size=(1,102), dtype=torch.float32, requires_grad=True)
    optimizer0 = torch.optim.AdamW((x0_random,), lr=10)
    best_loss = 10e5
    count = 0
    losses = []
    for epoch in range(1,100):
        optimizer0.zero_grad()
        x0_samples = utils.postprocess(
            vae0.sample(x0_random.to(device), 100, device, 
                        **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False}).squeeze(1),
            'adult_income'
        )
        obj0 = likelihood0(model0(x0_samples.to(device))).variance.sum(axis=0).mean() + ((x0_random.to(device) - x0_samples)**2).mean()
#         print(obj0.shape)
        loss = -obj0
        loss.backward()
        optimizer0.step()
        if loss < best_loss:
            best_loss = loss
            count = 0
            losses.append(loss)
        else:
            count += 1
        if count  == 5:
            break
    candidates.append(x0_random.detach().clone())
    
    if epoch_outer % 100 == 0:
        new_vals = torch.concatenate(candidates)
        new_queries,_ = vae0(new_vals.to(device), **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False})
        new_labels = blackbox(new_queries)
        neg_queried = torch.concatenate([neg_queried, new_queries])
        neg_labels = torch.concatenate([neg_labels, (0.5*(torch.sign(new_labels - 0.5) + 1.0)).long().detach().clone()])
        print(neg_labels.shape, neg_labels.dtype)
        print(neg_queried.shape, neg_queried.dtype)
        model0, likelihood0 = train(neg_queried.detach().clone(),neg_labels.flatten())
        model0.eval()
        likelihood0.eval()
        candidates = []

  2%|██▌                                                                                                     | 99/4000 [03:34<2:21:58,  2.18s/it]

torch.Size([1100, 1]) torch.int64
torch.Size([1100, 102]) torch.float32



  0%|                                                                                                                     | 0/50 [00:00<?, ?it/s][A
 14%|███████████████▎                                                                                             | 7/50 [00:00<00:00, 62.07it/s][A
 28%|██████████████████████████████▏                                                                             | 14/50 [00:00<00:00, 60.56it/s][A
 42%|█████████████████████████████████████████████▎                                                              | 21/50 [00:00<00:00, 61.04it/s][A
 56%|████████████████████████████████████████████████████████████▍                                               | 28/50 [00:00<00:00, 62.82it/s][A
 70%|███████████████████████████████████████████████████████████████████████████▌                                | 35/50 [00:00<00:00, 63.89it/s][A
 84%|██████████████████████████████████████████████████████████████████████████████████████████▋         

torch.Size([1200, 1]) torch.int64
torch.Size([1200, 102]) torch.float32



  0%|                                                                                                                     | 0/50 [00:00<?, ?it/s][A
 12%|█████████████                                                                                                | 6/50 [00:00<00:00, 52.21it/s][A
 26%|████████████████████████████                                                                                | 13/50 [00:00<00:00, 57.82it/s][A
 40%|███████████████████████████████████████████▏                                                                | 20/50 [00:00<00:00, 59.47it/s][A
 54%|██████████████████████████████████████████████████████████▎                                                 | 27/50 [00:00<00:00, 60.19it/s][A
 68%|█████████████████████████████████████████████████████████████████████████▍                                  | 34/50 [00:00<00:00, 60.74it/s][A
 82%|████████████████████████████████████████████████████████████████████████████████████████▌           

torch.Size([1300, 1]) torch.int64
torch.Size([1300, 102]) torch.float32



  0%|                                                                                                                     | 0/50 [00:00<?, ?it/s][A
 14%|███████████████▎                                                                                             | 7/50 [00:00<00:00, 61.58it/s][A
 28%|██████████████████████████████▏                                                                             | 14/50 [00:00<00:00, 63.49it/s][A
 42%|█████████████████████████████████████████████▎                                                              | 21/50 [00:00<00:00, 64.51it/s][A
 56%|████████████████████████████████████████████████████████████▍                                               | 28/50 [00:00<00:00, 64.95it/s][A
 70%|███████████████████████████████████████████████████████████████████████████▌                                | 35/50 [00:00<00:00, 65.11it/s][A
 84%|██████████████████████████████████████████████████████████████████████████████████████████▋         

torch.Size([1400, 1]) torch.int64
torch.Size([1400, 102]) torch.float32



  0%|                                                                                                                     | 0/50 [00:00<?, ?it/s][A
 12%|█████████████                                                                                                | 6/50 [00:00<00:00, 57.78it/s][A
 26%|████████████████████████████                                                                                | 13/50 [00:00<00:00, 62.21it/s][A
 40%|███████████████████████████████████████████▏                                                                | 20/50 [00:00<00:00, 63.72it/s][A
 54%|██████████████████████████████████████████████████████████▎                                                 | 27/50 [00:00<00:00, 64.89it/s][A
 68%|█████████████████████████████████████████████████████████████████████████▍                                  | 34/50 [00:00<00:00, 65.04it/s][A
 82%|████████████████████████████████████████████████████████████████████████████████████████▌           

In [21]:
candidates = []
pos_queried = warm_start_x1.clone()
pos_labels = warm_start_y1.clone().unsqueeze(1)
print(pos_labels.shape, pos_labels.dtype)
print(pos_queried.shape, pos_queried.dtype)

for epoch_outer in tqdm(range(1, 4001)):
    x1_random = torch.normal(0.,1.,size=(1,102), dtype=torch.float32, requires_grad=True)
    optimizer1 = torch.optim.AdamW((x1_random,), lr=10)
    best_loss = 10e5
    count = 0
    losses = []
    for epoch in range(1,100):
        optimizer1.zero_grad()
        x1_samples = utils.postprocess(
            vae0.sample(x1_random.to(device), 100, device, 
                        **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False}).squeeze(1),
            'adult_income'
        )
        obj1 = likelihood1(model1(x1_samples.to(device))).variance.sum(axis=0).mean() + ((x0_random.to(device) - x0_samples)**2).mean()
#         print(obj0.shape)
        loss = -obj1
        loss.backward()
        optimizer1.step()
        if loss < best_loss:
            best_loss = loss
            count = 0
            losses.append(loss)
        else:
            count += 1
        if count  == 5:
            break
    candidates.append(x1_random.detach().clone())
    
    if epoch_outer % 100 == 0:
        new_vals = torch.concatenate(candidates)
        new_queries,_ = vae1(new_vals.to(device), **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False})
        new_labels = blackbox(new_queries)
        pos_queried = torch.concatenate([pos_queried, new_queries])
        pos_labels = torch.concatenate([pos_labels, (0.5*(torch.sign(new_labels - 0.5) + 1.0)).long().detach().clone()])
        model1, likelihood1 = train(pos_queried.detach().clone(),pos_labels.flatten())
        model1.eval()
        likelihood1.eval()
        candidates = []

torch.Size([1000, 1]) torch.int64
torch.Size([1000, 102]) torch.float32


  2%|███▉                                                                                                                                                          | 99/4000 [00:31<17:27,  3.72it/s]
  0%|                                                                                                                                                                         | 0/50 [00:00<?, ?it/s][A
  8%|████████████▉                                                                                                                                                    | 4/50 [00:00<00:01, 39.59it/s][A
 16%|█████████████████████████▊                                                                                                                                       | 8/50 [00:00<00:01, 39.53it/s][A
 24%|██████████████████████████████████████▍                                                                                                                         | 12/50 [00:00<00:00, 39.35it/s][

In [22]:
pos_labels.shape

torch.Size([5000, 1])

In [23]:
new_labels.shape

torch.Size([100, 1])

In [24]:
# parity in data
np.abs(dfy[dfx[dfx.gender == 0].index].mean() - dfy[dfx[dfx.gender == 1].index].mean())

0.19911019753072282

In [25]:
bb_input0 = torch.from_numpy(dfx_0.values).float().to(device)
bb_input1 = torch.from_numpy(dfx_1.values).float().to(device)

In [28]:
y0 = blackbox(bb_input0)
y1 = blackbox(bb_input1)

In [31]:
y0_ = torch.round(y0)
y1_ = torch.round(y1)

In [32]:
y0_.sum()

tensor(7659., device='cuda:0', grad_fn=<SumBackward0>)

In [37]:
torch.abs(y0_.mean() - y1_.mean())

tensor(0.1869, device='cuda:0', grad_fn=<AbsBackward0>)

In [41]:
torch.abs(pos_labels.squeeze().float().mean() - neg_labels.squeeze().float().mean())

tensor(0.0982, device='cuda:0')

In [42]:
pos_labels.float().mean()

tensor(0.0520, device='cuda:0')

In [44]:
neg_labels.float().mean()

tensor(0.1502, device='cuda:0')