In [49]:
import sys 
sys.path.append('../..')
from cox.utils import Parameters
from cox.store import Store
from cox.readers import CollectionReader
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import itertools
import numpy as np
import torch as ch
from torch import Tensor
import torch.nn as nn
from torch.distributions import Gumbel, Uniform
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
import datetime
from delphi.oracle import oracle

# set default tensor type 
ch.set_default_tensor_type(ch.cuda.FloatTensor)

## Default Experiment Parameters

In [137]:
# procedure hyperparameters
args = Parameters({ 
    'epochs': 25,
    'num_workers': 0, 
    'batch_size': 100,
    'bias': True,
    'num_samples': 1000,
    'clamp': True, 
    'radius': 5.0, 
    'var_lr': 1e-2,
    'lr': 1e-1,
    'shuffle': False, 
    'samples': 10000,  # number of samples to generate for ground truth
    'in_features': 10, # number of in-features to multi-log-reg
    'k': 10, # number of classes
    'lower': -1, # lower bound for generating ground truth weights
    'upper': 1,  # upper bound for generating ground truth weights
    'trials': 10,
})

# CE Latent Variable Model Loss

In [122]:
gumbel = Gumbel(0, 1)

class GumbelCE(ch.autograd.Function):
    @staticmethod
    def forward(ctx, pred, targ):
        ctx.save_for_backward(pred, targ)
        loss = ch.nn.CrossEntropyLoss()
        return loss(pred, targ)

    @staticmethod
    def backward(ctx, grad_output):
        pred, targ = ctx.saved_tensors
        # make num_samples copies of pred logits
        stacked = pred[None, ...].repeat(1000, 1, 1)        
        # add gumbel noise to logits
        rand_noise = gumbel.sample(stacked.size())
        noised = stacked + rand_noise 
        noised_labs = noised.argmax(-1)
        # remove the logits from the trials, where the kth logit is not the largest value
        good_mask = noised_labs.eq(targ)[..., None]
        inner_exp = 1 - ch.exp(-rand_noise)
        avg = (inner_exp * good_mask).sum(0) / (good_mask.sum(0) + 1e-5) / pred.size(0)
        return -avg , None
    
class TruncatedGumbelCE(ch.autograd.Function):
    @staticmethod
    def forward(ctx, pred, targ, phi):
        ctx.save_for_backward(pred, targ)
        ctx.phi = phi
        ce_loss = ch.nn.CrossEntropyLoss()
        return ce_loss(pred, targ)

    @staticmethod
    def backward(ctx, grad_output):
        pred, targ = ctx.saved_tensors
        # make num_samples copies of pred logits
        stacked = pred[None, ...].repeat(args.num_samples, 1, 1)   
        # add gumbel noise to logits
        rand_noise = gumbel.sample(stacked.size())
        noised = stacked + rand_noise 
        # truncate - if one of the noisy logits does not fall within the truncation set, remove it
        filtered = ch.all(ctx.phi(noised).bool(), dim=2).float().unsqueeze(2)
        noised_labs = noised.argmax(-1)
        # mask takes care of invalid logits and truncation set
        mask = noised_labs.eq(targ)[..., None] * filtered
        inner_exp = 1 - ch.exp(-rand_noise)
                
        avg = ((inner_exp * mask).sum(0) / (mask.sum(0) + 1e-5) - (inner_exp * filtered).sum(0) / (filtered.sum(0) + 1e-5)) 
        return -avg / pred.size(0), None, None

# Truncated Multinomial Logistic Regression Experiments

Membership oracles for Multinomial Logistic Regression Logits 

In [52]:
class DNN_Lower(oracle): 
    """
    Lower bound truncation on the DNN logits.
    """
    def __init__(self, lower): 
        self.lower = lower
        
    def __call__(self, x): 
        return (x > self.lower).float()

In [53]:
class Identity(oracle): 
    def __call__(self, x): 
        return ch.ones(x.size())

Truncate Dataset

In [147]:
phi = DNN_Lower(ch.full(ch.Size([args.K,]), -5))
# phi = Identity()

In [148]:
# TRUNC CE LOSS TABLE FOR METRICS
LATENT_CE_TABLE_NAME = 'trunc_test_9'

STORE_PATH = '/home/pstefanou/MultinomialLogisticRegression'
store = Store(STORE_PATH)

store.add_table(LATENT_CE_TABLE_NAME, { 
    'trunc_train_acc': float, 
    'trunc_val_acc': float, 
    'trunc_train_loss': float, 
    'trunc_val_loss': float,
    'naive_train_acc': float, 
    'naive_val_acc': float, 
    'naive_train_loss': float, 
    'naive_val_loss': float,
    'trunc_test_acc': float, 
    'naive_test_acc': float,
    'epoch': int,
})

Logging in: /home/pstefanou/MultinomialLogisticRegression/6223ad17-746c-428e-b3c2-27349706ff4f


<cox.store.Table at 0x7faf47940048>

In [None]:
U = Uniform(args.lower, args.upper) # distribution to generate ground-truth parameters
U_ = Uniform(-5, 5) # distribution to generate samples

# perform trials number of experiments
for i in range(args.trials):
    
    # continue to generate synthetic data until survival probability of more than 40%
    alpha = None
    while alpha is None or alpha < .5:
        # generate ground-truth from uniform distribution
        ground_truth = nn.Linear(in_features=args.IN_FEATURES, out_features=args.K, bias=args.bias)
        ground_truth.weight = nn.Parameter(U.sample(ch.Size([args.K, args.IN_FEATURES])))
        if ground_truth.bias is not None: 
            ground_truth.bias = nn.Parameter(U.sample(ch.Size([args.K,])))
        # independent variable 
        X = U_.sample(ch.Size([args.samples, args.IN_FEATURES]))
        # determine base model logits 
        z = ground_truth(X)
        # apply softmax to unnormalized likelihoods
        y = ch.argmax(ch.nn.Softmax(dim=1)(z), dim=1)

        # TRUNCATE
        trunc = phi(z)
        indices = ch.all(trunc.bool(), dim=1).float().nonzero(as_tuple=False).flatten()
        y_trunc = y[indices]
        x_trunc = X[indices]
        alpha = x_trunc.size(0) / X.size(0)

        # all synthetic data 
        ds = TensorDataset(x_trunc, y_trunc)
        # split ds into training and validation data sets - 80% training, 20% validation
        train_length = int(len(ds)*.8)
        val_length = len(ds) - train_length
        train_ds, val_ds = ch.utils.data.random_split(ds, [train_length, val_length])
        # train and validation loaders
        train_loader = DataLoader(train_ds, num_workers=args.num_workers, batch_size=args.batch_size)
        val_loader = DataLoader(val_ds, num_workers=args.num_workers, batch_size=args.batch_size)

        # test dataset
        y_test = y[~indices]
        x_test = X[~indices]
        
        print('alpha: {}'.format(alpha))
        
    print("alpha: {}".format(alpha))

    
    # reset classifier models at the beginning of each trial
    trunc_multi_log_reg = nn.Linear(in_features=args.IN_FEATURES, out_features=args.K, bias=args.bias)
    naive_multi_log_reg = nn.Linear(in_features=args.IN_FEATURES, out_features=args.K, bias=args.bias)
    # optimizer and scheduler
    trunc_opt = ch.optim.SGD(trunc_multi_log_reg.parameters(), lr=1e-1)
    naive_opt = ch.optim.SGD(naive_multi_log_reg.parameters(), lr=1e-1)
    # use cosine annealing for learning rate scheduler
    trunc_scheduler = ch.optim.lr_scheduler.CosineAnnealingLR(trunc_opt, args.epochs)
    naive_scheduler = ch.optim.lr_scheduler.CosineAnnealingLR(naive_opt, args.epochs)
    # gradients
    trunc_ce_loss = TruncatedGumbelCE.apply
    ce_loss = ch.nn.CrossEntropyLoss()
    
    # train classifier
    for epoch in range(args.epochs): 
        # train loop
        trunc_train_loss, trunc_train_acc = Tensor([]), Tensor([])
        naive_train_loss, naive_train_acc = Tensor([]), Tensor([])
        for batch_X, batch_y in train_loader: 
            # truncated multinomial regression
            trunc_opt.zero_grad()
            pred = trunc_multi_log_reg(batch_X)
            loss = trunc_ce_loss(pred, batch_y, phi)
            loss.backward() 
            trunc_opt.step()
            trunc_scheduler.step()
            # keep track of truncated algorithm training loss and accuracy
            acc = (ch.argmax(ch.nn.Softmax(dim=1)(pred), dim=1) == batch_y).sum() / batch_y.size(0)
            trunc_train_loss = ch.cat([trunc_train_loss, Tensor([loss])]) if trunc_train_loss.size() != ch.Size([0]) else Tensor([loss])
            trunc_train_acc = ch.cat([trunc_train_acc, Tensor([acc])]) if trunc_train_acc.size() != ch.Size([0]) else Tensor([acc])
            
            # naive multinomial regression
            naive_opt.zero_grad()
            pred = naive_multi_log_reg(batch_X)
            loss = ce_loss(pred, batch_y)
            loss.backward() 
            naive_opt.step()
            naive_scheduler.step()
            # keep track of naive algorithm training loss and accuracy
            acc = (ch.argmax(ch.nn.Softmax(dim=1)(pred), dim=1) == batch_y).sum() / batch_y.size(0)
            naive_train_loss = ch.cat([naive_train_loss, Tensor([loss])]) if naive_train_loss.size() != ch.Size([0]) else Tensor([loss])
            naive_train_acc = ch.cat([naive_train_acc, Tensor([acc])]) if naive_train_acc.size() != ch.Size([0]) else Tensor([acc])
        # validation loop
        trunc_val_loss, trunc_val_acc = Tensor([]), Tensor([])
        naive_val_loss, naive_val_acc = Tensor([]), Tensor([])
        with ch.no_grad(): 
            for batch_X, batch_y in val_loader: 
                # truncated validation loop
                pred = trunc_multi_log_reg(batch_X)
                loss = trunc_ce_loss(pred, batch_y, phi)
                # keep track of algorithm validation loss and accuracy
                acc = (ch.argmax(ch.nn.Softmax(dim=1)(pred), dim=1) == batch_y).sum() / batch_y.size(0)            
                trunc_val_loss = ch.cat([trunc_val_loss, Tensor([loss])]) if trunc_val_loss.size() != ch.Size([0]) else Tensor([loss])
                trunc_val_acc = ch.cat([trunc_val_acc, Tensor([acc])]) if trunc_val_acc.size() != ch.Size([0]) else Tensor([acc])
                
                # naive validation loop
                pred = naive_multi_log_reg(batch_X)
                loss = ce_loss(pred, batch_y)
                # keep track of algorithm validation loss and accuracy
                acc = (ch.argmax(ch.nn.Softmax(dim=1)(pred), dim=1) == batch_y).sum() / batch_y.size(0)            
                naive_val_loss = ch.cat([naive_val_loss, Tensor([loss])]) if naive_val_loss.size() != ch.Size([0]) else Tensor([loss])
                naive_val_acc = ch.cat([naive_val_acc, Tensor([acc])]) if naive_val_acc.size() != ch.Size([0]) else Tensor([acc])

            # test set accuracy
            trunc_test_pred = trunc_multi_log_reg(x_test)
            naive_test_pred = naive_multi_log_reg(x_test)
            trunc_test_acc = (ch.argmax(ch.nn.Softmax(dim=1)(trunc_test_pred), dim=1) == y_test).sum() / y.size(0)
            naive_test_acc = (ch.argmax(ch.nn.Softmax(dim=1)(naive_test_pred), dim=1) == y_test).sum() / y.size(0)
                
        store[LATENT_CE_TABLE_NAME].append_row({ 
            'trunc_train_acc': float(trunc_train_acc.mean()), 
            'trunc_val_acc': float(trunc_val_acc.mean()), 
            'trunc_train_loss': float(trunc_train_loss.mean()), 
            'trunc_val_loss': float(trunc_val_loss.mean()),
            'naive_train_acc': float(naive_train_acc.mean()), 
            'naive_val_acc': float(naive_val_acc.mean()), 
            'naive_train_loss': float(naive_train_loss.mean()), 
            'naive_val_loss': float(naive_val_loss.mean()),
            'trunc_test_acc': float(trunc_test_acc), 
            'naive_test_acc': float(naive_test_acc),
            'epoch': int(epoch + 1),
        })
    
store.close()

alpha: 0.1152
alpha: 0.2724
alpha: 0.1377
alpha: 0.1331
alpha: 0.1183
alpha: 0.1196
alpha: 0.1312
alpha: 0.1542
alpha: 0.1612
alpha: 0.174
alpha: 0.1066
alpha: 0.1641
alpha: 0.1478
alpha: 0.1903
alpha: 0.1547
alpha: 0.166
alpha: 0.223
alpha: 0.1476
alpha: 0.13
alpha: 0.149
alpha: 0.1802
alpha: 0.1276
alpha: 0.1405
alpha: 0.1186
alpha: 0.1302
alpha: 0.1231
alpha: 0.106
alpha: 0.1405
alpha: 0.1402
alpha: 0.1364
alpha: 0.1805
alpha: 0.139
alpha: 0.1607
alpha: 0.1404
alpha: 0.1591
alpha: 0.1618
alpha: 0.1549
alpha: 0.1814
alpha: 0.1738
alpha: 0.1498
alpha: 0.1589
alpha: 0.1504
alpha: 0.1764
alpha: 0.1211
alpha: 0.0652
alpha: 0.0725
alpha: 0.1755
alpha: 0.1021
alpha: 0.2049
alpha: 0.2368
alpha: 0.2531
alpha: 0.1472
alpha: 0.1614
alpha: 0.0866
alpha: 0.1682
alpha: 0.0657
alpha: 0.1978
alpha: 0.2124
alpha: 0.1827
alpha: 0.214
alpha: 0.2314
alpha: 0.1567
alpha: 0.1317
alpha: 0.1443
alpha: 0.2134
alpha: 0.1369
alpha: 0.1004
alpha: 0.1216
alpha: 0.1173
alpha: 0.1389
alpha: 0.134
alpha: 0.194
alp

alpha: 0.1726
alpha: 0.2101
alpha: 0.2019
alpha: 0.1365
alpha: 0.185
alpha: 0.2223
alpha: 0.1928
alpha: 0.1648
alpha: 0.1325
alpha: 0.1499
alpha: 0.1395
alpha: 0.1256
alpha: 0.1264
alpha: 0.1512
alpha: 0.2451
alpha: 0.1334
alpha: 0.1392
alpha: 0.1756
alpha: 0.1445
alpha: 0.268
alpha: 0.1614
alpha: 0.1558
alpha: 0.1353
alpha: 0.1722
alpha: 0.1375
alpha: 0.1088
alpha: 0.206
alpha: 0.1179
alpha: 0.1973
alpha: 0.2084
alpha: 0.1982
alpha: 0.1382
alpha: 0.1448
alpha: 0.1096
alpha: 0.2414
alpha: 0.1034
alpha: 0.1452
alpha: 0.2191
alpha: 0.1964
alpha: 0.1608
alpha: 0.0529
alpha: 0.1404
alpha: 0.1686
alpha: 0.2244
alpha: 0.1921
alpha: 0.2333
alpha: 0.2062
alpha: 0.2419
alpha: 0.1663
alpha: 0.1837
alpha: 0.062
alpha: 0.2396
alpha: 0.1174
alpha: 0.2329
alpha: 0.1651
alpha: 0.1255
alpha: 0.1925
alpha: 0.134
alpha: 0.1576
alpha: 0.1644
alpha: 0.1959
alpha: 0.2137
alpha: 0.1305
alpha: 0.1214
alpha: 0.1492
alpha: 0.1259
alpha: 0.1639
alpha: 0.2264
alpha: 0.1435
alpha: 0.1597
alpha: 0.1587
alpha: 0.11

alpha: 0.1045
alpha: 0.1507
alpha: 0.1873
alpha: 0.1603
alpha: 0.1049
alpha: 0.1518
alpha: 0.2252
alpha: 0.1644
alpha: 0.1926
alpha: 0.1514
alpha: 0.1467
alpha: 0.1851
alpha: 0.1641
alpha: 0.0986
alpha: 0.2384
alpha: 0.1861
alpha: 0.0746
alpha: 0.1423
alpha: 0.1589
alpha: 0.1774
alpha: 0.1467
alpha: 0.1544
alpha: 0.1854
alpha: 0.166
alpha: 0.1034
alpha: 0.0945
alpha: 0.1159
alpha: 0.1577
alpha: 0.1797
alpha: 0.1732
alpha: 0.1732
alpha: 0.1254
alpha: 0.1124
alpha: 0.2204
alpha: 0.1962
alpha: 0.1044
alpha: 0.2084
alpha: 0.1666
alpha: 0.0952
alpha: 0.1669
alpha: 0.1727
alpha: 0.2741
alpha: 0.1089
alpha: 0.186
alpha: 0.1326
alpha: 0.2695
alpha: 0.0839
alpha: 0.1949
alpha: 0.0929
alpha: 0.182
alpha: 0.1822
alpha: 0.1305
alpha: 0.1549
alpha: 0.1787
alpha: 0.1197
alpha: 0.1598
alpha: 0.1358
alpha: 0.2436
alpha: 0.0737
alpha: 0.1683
alpha: 0.2239
alpha: 0.2053
alpha: 0.0974
alpha: 0.1401
alpha: 0.1667
alpha: 0.1932
alpha: 0.1312
alpha: 0.1239
alpha: 0.171
alpha: 0.1836
alpha: 0.1649
alpha: 0.2

alpha: 0.1004
alpha: 0.1042
alpha: 0.1771
alpha: 0.1647
alpha: 0.2231
alpha: 0.2255
alpha: 0.1831
alpha: 0.1519
alpha: 0.1751
alpha: 0.1659
alpha: 0.2071
alpha: 0.1775
alpha: 0.1125
alpha: 0.2219
alpha: 0.182
alpha: 0.166
alpha: 0.2149
alpha: 0.1612
alpha: 0.1716
alpha: 0.0978
alpha: 0.1751
alpha: 0.1413
alpha: 0.2771
alpha: 0.2328
alpha: 0.1995
alpha: 0.2614
alpha: 0.1476
alpha: 0.0947
alpha: 0.1223
alpha: 0.2225
alpha: 0.1185
alpha: 0.1454
alpha: 0.1851
alpha: 0.1769
alpha: 0.1813
alpha: 0.2281
alpha: 0.1849
alpha: 0.1497
alpha: 0.2026
alpha: 0.1107
alpha: 0.1442
alpha: 0.1038
alpha: 0.191
alpha: 0.1626
alpha: 0.0949
alpha: 0.1185
alpha: 0.127
alpha: 0.2043
alpha: 0.1304
alpha: 0.1584
alpha: 0.0999
alpha: 0.164
alpha: 0.1525
alpha: 0.2113
alpha: 0.1288
alpha: 0.1224
alpha: 0.1564
alpha: 0.1343
alpha: 0.1738
alpha: 0.1353
alpha: 0.1453
alpha: 0.1778
alpha: 0.1698
alpha: 0.1869
alpha: 0.1702
alpha: 0.1285
alpha: 0.1428
alpha: 0.1235
alpha: 0.1431
alpha: 0.0773
alpha: 0.1111
alpha: 0.21

# Read Experiment Data from Store

In [None]:
reader = CollectionReader(STORE_PATH)
results = reader.df(LATENT_CE_TABLE_NAME)
reader.close() # close reader
results.head()

# Loss and Accuracy Curves

In [None]:
sns.lineplot(data=results, x='epoch', y='trunc_train_loss', label='trunc train loss')
sns.lineplot(data=results, x='epoch', y='naive_train_loss', label='naive train loss')
sns.lineplot(data=results, x='epoch', y='trunc_val_loss', color='red', label='trunc val loss')
ax = sns.lineplot(data=results, x='epoch', y='naive_val_loss', color='red', label='naive val loss')
ax.set(xlabel='epoch', ylabel='CE Loss')
plt.show()

sns.lineplot(data=results, x='epoch', y='trunc_train_acc', label='trunc train acc')
sns.lineplot(data=results, x='epoch', y='naive_train_acc', label='naive train acc')
sns.lineplot(data=results, x='epoch', y='trunc_val_acc', label='trunc val acc')
ax = sns.lineplot(data=results, x='epoch', y='naive_val_acc', label='naive val acc')
ax.set(xlabel='epoch', ylabel='Accuracy')
plt.show()

sns.lineplot(data=results, x='epoch', y='trunc_test_acc', label='trunc test acc')
ax = sns.lineplot(data=results, x='epoch', y='naive_test_acc', label='naive test acc')
ax.set(xlabel='epoch', ylabel='Test Accuracy')
plt.show()