In [49]:
import sys 
sys.path.append('../..')
from cox.utils import Parameters
from cox.store import Store
from cox.readers import CollectionReader
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import itertools
import numpy as np
import torch as ch
from torch import Tensor
import torch.nn as nn
from torch.distributions import Gumbel, Uniform
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
import datetime
from delphi.oracle import oracle

# set default tensor type
# Only switch the default to CUDA float tensors when a GPU is present;
# without the guard this call raises on CPU-only machines and the notebook
# cannot be run at all.
if ch.cuda.is_available():
    ch.set_default_tensor_type(ch.cuda.FloatTensor)

## Default Experiment Parameters

In [154]:
# Experiment configuration, read throughout the notebook as attributes of `args`.
args = Parameters(dict(
    # --- optimisation ---
    epochs=25,
    num_workers=0,
    batch_size=100,
    bias=True,
    num_samples=100000,  # Gumbel noise draws used by the custom loss backward passes
    # NOTE(review): clamp / radius / var_lr are not read by the cells shown here
    clamp=True,
    radius=5.0,
    var_lr=1e-2,
    lr=1e-1,
    shuffle=False,
    # --- synthetic data ---
    samples=10000,   # number of samples to generate for ground truth
    in_features=10,  # number of in-features to multi-log-reg
    k=10,            # number of classes
    lower=-1,        # lower bound for generating ground truth weights
    upper=1,         # upper bound for generating ground truth weights
    trials=10,
))

# CE Latent Variable Model Loss

In [153]:
# standard Gumbel distribution used to perturb logits in the backward passes
gumbel = Gumbel(0, 1)

class GumbelCE(ch.autograd.Function):
    """
    Cross-entropy loss with a sampling-based gradient.

    The forward pass returns the ordinary mean cross-entropy of `pred`
    against `targ`. The backward pass does not differentiate that value;
    instead it draws `args.num_samples` Gumbel noise tensors, adds them to the
    logits, keeps the draws whose noisy arg-max equals the target, and returns
    the negated average of `1 - exp(-noise)` over those draws.
    """
    @staticmethod
    def forward(ctx, pred, targ):
        """
        Args:
            pred: logits; the backward pass indexes it as (batch, classes).
            targ: class indices, one per row of `pred`.
        Returns:
            Scalar mean cross-entropy loss.
        """
        ctx.save_for_backward(pred, targ)
        return ch.nn.functional.cross_entropy(pred, targ)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Returns the gradient estimate for `pred` (scaled by `grad_output`) and
        `None` for `targ`, which is not differentiable.
        """
        pred, targ = ctx.saved_tensors
        # one Gumbel draw per (sample, example, class); broadcasting adds the
        # logits without materialising num_samples copies of them
        rand_noise = gumbel.sample(ch.Size([args.num_samples, *pred.size()]))
        noised = pred[None, ...] + rand_noise
        noised_labs = noised.argmax(-1)
        # keep only the draws whose noisy arg-max is the target class
        good_mask = noised_labs.eq(targ)[..., None]
        inner_exp = 1 - ch.exp(-rand_noise)
        # 1e-5 guards against examples for which no draw matched the target;
        # the division by the batch size mirrors the mean reduction in forward
        avg = (inner_exp * good_mask).sum(0) / (good_mask.sum(0) + 1e-5) / pred.size(0)
        # chain rule: honour the upstream gradient instead of assuming it is 1
        return -avg * grad_output, None
    
class TruncatedGumbelCE(ch.autograd.Function):
    """
    Cross-entropy loss with a sampling-based gradient that accounts for a
    truncation set described by the membership oracle `phi`.

    The forward pass returns the ordinary mean cross-entropy. The backward
    pass perturbs the logits with `args.num_samples` Gumbel draws and returns
    the difference of two averages of `1 - exp(-noise)`: one over draws that
    lie in the truncation set AND whose noisy arg-max equals the target, and
    one over all draws that lie in the truncation set.
    """
    @staticmethod
    def forward(ctx, pred, targ, phi):
        """
        Args:
            pred: logits; the backward pass indexes it as (batch, classes).
            targ: class indices, one per row of `pred`.
            phi: callable applied to the noisy logits; a draw is kept only if
                every class entry of the result is truthy.
        Returns:
            Scalar mean cross-entropy loss.
        """
        ctx.save_for_backward(pred, targ)
        ctx.phi = phi
        return ch.nn.functional.cross_entropy(pred, targ)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Returns the gradient estimate for `pred` (scaled by `grad_output`) and
        `None` for the non-differentiable `targ` and `phi`.
        """
        pred, targ = ctx.saved_tensors
        # one Gumbel draw per (sample, example, class); broadcasting adds the
        # logits without materialising num_samples copies of them
        rand_noise = gumbel.sample(ch.Size([args.num_samples, *pred.size()]))
        noised = pred[None, ...] + rand_noise
        # truncate - if one of the noisy logits does not fall within the truncation set, remove it
        filtered = ch.all(ctx.phi(noised).bool(), dim=2).float().unsqueeze(2)
        noised_labs = noised.argmax(-1)
        # mask takes care of invalid logits and truncation set
        mask = noised_labs.eq(targ)[..., None] * filtered
        inner_exp = 1 - ch.exp(-rand_noise)

        # 1e-5 guards against empty selections in either average
        avg = ((inner_exp * mask).sum(0) / (mask.sum(0) + 1e-5)
               - (inner_exp * filtered).sum(0) / (filtered.sum(0) + 1e-5))
        # divide by the batch size to mirror the mean reduction in forward and
        # apply the chain rule with the upstream gradient
        return -avg / pred.size(0) * grad_output, None, None

# Truncated Multinomial Logistic Regression Experiments

Membership oracles for Multinomial Logistic Regression Logits 

In [52]:
class DNN_Lower(oracle):
    """
    Membership oracle that accepts logits lying strictly above a lower bound.
    """
    def __init__(self, lower):
        # threshold compared element-wise against the logits
        self.lower = lower

    def __call__(self, x):
        # 1.0 where the entry exceeds the bound, 0.0 elsewhere
        above = x > self.lower
        return above.float()

In [53]:
class Identity(oracle):
    """
    Membership oracle that accepts everything (no truncation).
    """
    def __call__(self, x):
        # all-ones mask with the shape of x, allocated on x's device so that
        # it can be combined with x-derived tensors without a device mismatch
        return ch.ones(x.size(), device=x.device)

Truncate Dataset

In [155]:
# Per-class lower bound on the logits. The fill value is written as a float:
# with an integer fill and no dtype, `ch.full` produces an integer tensor or
# is rejected outright, depending on the PyTorch version.
phi = DNN_Lower(ch.full(ch.Size([args.K,]), -7.0))
# phi = Identity()

In [156]:
# TRUNC CE LOSS TABLE FOR METRICS
LATENT_CE_TABLE_NAME = 'trunc_test_10'

STORE_PATH = '/home/pstefanou/MultinomialLogisticRegression'

# Re-running this cell rebinds `store`; close the store from a previous run
# first so that its file is not left open in write mode. A store left open
# that way appears to be what makes CollectionReader fail later with
# "already opened, but not in read-only mode".
_previous_store = globals().get('store')
if _previous_store is not None:
    _previous_store.close()
store = Store(STORE_PATH)

# one row per (trial, epoch): train/val loss and accuracy plus test accuracy
# for both the truncated and the naive classifier
store.add_table(LATENT_CE_TABLE_NAME, {
    'trunc_train_acc': float,
    'trunc_val_acc': float,
    'trunc_train_loss': float,
    'trunc_val_loss': float,
    'naive_train_acc': float,
    'naive_val_acc': float,
    'naive_train_loss': float,
    'naive_val_loss': float,
    'trunc_test_acc': float,
    'naive_test_acc': float,
    'epoch': int,
})

Logging in: /home/pstefanou/MultinomialLogisticRegression/2c04de1a-c9d1-4c79-8f6f-a3bac6c9dada


<cox.store.Table at 0x7faf37373978>

In [157]:
U = Uniform(args.lower, args.upper) # distribution to generate ground-truth parameters
U_ = Uniform(-5, 5) # distribution to generate samples


def _accuracy(logits, labels):
    """
    Fraction of rows of `logits` whose arg-max equals the matching entry of
    `labels`, as a Python float.

    Softmax is monotone, so the arg-max is taken directly on the logits.
    Returns NaN when there are no rows to score.
    """
    if labels.size(0) == 0:
        return float('nan')
    return (logits.argmax(dim=1) == labels).float().mean().item()


def _mean(values):
    """Arithmetic mean of a list of floats; NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


try:
    # perform trials number of experiments
    for i in range(args.trials):

        # continue to generate synthetic data until survival probability of more than 40%
        alpha = None
        while alpha is None or alpha < .4:
            # generate ground-truth from uniform distribution
            ground_truth = nn.Linear(in_features=args.IN_FEATURES, out_features=args.K, bias=args.bias)
            ground_truth.weight = nn.Parameter(U.sample(ch.Size([args.K, args.IN_FEATURES])))
            if ground_truth.bias is not None:
                ground_truth.bias = nn.Parameter(U.sample(ch.Size([args.K,])))
            # independent variable
            X = U_.sample(ch.Size([args.samples, args.IN_FEATURES]))
            # determine base model logits (data generation needs no autograd graph)
            with ch.no_grad():
                z = ground_truth(X)
            # NOTE(review): labels are the noise-free arg-max of the logits,
            # whereas the loss's backward perturbs logits with Gumbel noise -
            # confirm this is the intended generative model.
            y = z.argmax(dim=1)

            # TRUNCATE: a sample survives only if every logit is inside the set
            survives = ch.all(phi(z).bool(), dim=1)
            y_trunc = y[survives]
            x_trunc = X[survives]
            alpha = x_trunc.size(0) / X.size(0)

        # test dataset: the samples removed by truncation. A boolean mask is
        # negated here; bitwise-NOT on integer positions would yield negative
        # indices rather than the complement.
        y_test = y[~survives]
        x_test = X[~survives]

        # all surviving synthetic data (built once, for the accepted draw only)
        ds = TensorDataset(x_trunc, y_trunc)
        # split ds into training and validation data sets - 80% training, 20% validation
        train_length = int(len(ds)*.8)
        val_length = len(ds) - train_length
        train_ds, val_ds = ch.utils.data.random_split(ds, [train_length, val_length])
        # train and validation loaders
        train_loader = DataLoader(train_ds, num_workers=args.num_workers, batch_size=args.batch_size,
                                  shuffle=bool(args.shuffle))
        val_loader = DataLoader(val_ds, num_workers=args.num_workers, batch_size=args.batch_size)

        # reset classifier models at the beginning of each trial
        trunc_multi_log_reg = nn.Linear(in_features=args.IN_FEATURES, out_features=args.K, bias=args.bias)
        naive_multi_log_reg = nn.Linear(in_features=args.IN_FEATURES, out_features=args.K, bias=args.bias)
        # optimizer and scheduler
        trunc_opt = ch.optim.SGD(trunc_multi_log_reg.parameters(), lr=args.lr)
        naive_opt = ch.optim.SGD(naive_multi_log_reg.parameters(), lr=args.lr)
        # use cosine annealing for learning rate scheduler; T_max is counted in
        # epochs, so the schedulers are stepped once per epoch below
        trunc_scheduler = ch.optim.lr_scheduler.CosineAnnealingLR(trunc_opt, args.epochs)
        naive_scheduler = ch.optim.lr_scheduler.CosineAnnealingLR(naive_opt, args.epochs)
        # gradients
        trunc_ce_loss = TruncatedGumbelCE.apply
        ce_loss = ch.nn.CrossEntropyLoss()

        # train classifier
        for epoch in range(args.epochs):
            # train loop
            trunc_train_loss, trunc_train_acc = [], []
            naive_train_loss, naive_train_acc = [], []
            for batch_X, batch_y in train_loader:
                # truncated multinomial regression
                trunc_opt.zero_grad()
                pred = trunc_multi_log_reg(batch_X)
                loss = trunc_ce_loss(pred, batch_y, phi)
                loss.backward()
                trunc_opt.step()
                # keep track of truncated algorithm training loss and accuracy
                trunc_train_loss.append(loss.item())
                trunc_train_acc.append(_accuracy(pred, batch_y))

                # naive multinomial regression
                naive_opt.zero_grad()
                pred = naive_multi_log_reg(batch_X)
                loss = ce_loss(pred, batch_y)
                loss.backward()
                naive_opt.step()
                # keep track of naive algorithm training loss and accuracy
                naive_train_loss.append(loss.item())
                naive_train_acc.append(_accuracy(pred, batch_y))

            # advance the cosine schedules once per epoch
            trunc_scheduler.step()
            naive_scheduler.step()

            # validation loop
            trunc_val_loss, trunc_val_acc = [], []
            naive_val_loss, naive_val_acc = [], []
            with ch.no_grad():
                for batch_X, batch_y in val_loader:
                    # truncated validation loop
                    pred = trunc_multi_log_reg(batch_X)
                    loss = trunc_ce_loss(pred, batch_y, phi)
                    trunc_val_loss.append(loss.item())
                    trunc_val_acc.append(_accuracy(pred, batch_y))

                    # naive validation loop
                    pred = naive_multi_log_reg(batch_X)
                    loss = ce_loss(pred, batch_y)
                    naive_val_loss.append(loss.item())
                    naive_val_acc.append(_accuracy(pred, batch_y))

                # test set accuracy, normalised by the size of the test set
                trunc_test_acc = _accuracy(trunc_multi_log_reg(x_test), y_test)
                naive_test_acc = _accuracy(naive_multi_log_reg(x_test), y_test)

            store[LATENT_CE_TABLE_NAME].append_row({
                'trunc_train_acc': _mean(trunc_train_acc),
                'trunc_val_acc': _mean(trunc_val_acc),
                'trunc_train_loss': _mean(trunc_train_loss),
                'trunc_val_loss': _mean(trunc_val_loss),
                'naive_train_acc': _mean(naive_train_acc),
                'naive_val_acc': _mean(naive_val_acc),
                'naive_train_loss': _mean(naive_train_loss),
                'naive_val_loss': _mean(naive_val_loss),
                'trunc_test_acc': float(trunc_test_acc),
                'naive_test_acc': float(naive_test_acc),
                'epoch': int(epoch + 1),
            })
finally:
    # always release the store, even if a trial fails part-way through
    store.close()

# Read Experiment Data from Store

In [159]:
# Load the metrics table from every experiment found under STORE_PATH.
reader = CollectionReader(STORE_PATH)
try:
    results = reader.df(LATENT_CE_TABLE_NAME)
finally:
    reader.close() # close reader even if the table cannot be read
results.head()

  8%|▊         | 1/13 [00:00<00:00, 107.98it/s]


ValueError: The file '/home/pstefanou/MultinomialLogisticRegression/30d572c4-b61d-47c1-9bd1-0c81583b5d8a/store.h5' is already opened, but not in read-only mode (as requested).

# Loss and Accuracy Curves

In [None]:
def plot_metric_curves(results, curves, ylabel):
    """
    Draw one figure with a line per metric column, plotted against 'epoch'.

    Args:
        results: DataFrame holding an 'epoch' column and the metric columns.
        curves: iterable of (column name, legend label) pairs.
        ylabel: label for the y-axis.

    Colours come from the default colour cycle, so every curve in a figure is
    distinguishable.
    """
    fig, ax = plt.subplots()
    for column, label in curves:
        sns.lineplot(data=results, x='epoch', y=column, label=label, ax=ax)
    ax.set(xlabel='epoch', ylabel=ylabel)
    plt.show()


plot_metric_curves(results, [
    ('trunc_train_loss', 'trunc train loss'),
    ('naive_train_loss', 'naive train loss'),
    ('trunc_val_loss', 'trunc val loss'),
    ('naive_val_loss', 'naive val loss'),
], 'CE Loss')

plot_metric_curves(results, [
    ('trunc_train_acc', 'trunc train acc'),
    ('naive_train_acc', 'naive train acc'),
    ('trunc_val_acc', 'trunc val acc'),
    ('naive_val_acc', 'naive val acc'),
], 'Accuracy')

plot_metric_curves(results, [
    ('trunc_test_acc', 'trunc test acc'),
    ('naive_test_acc', 'naive test acc'),
], 'Test Accuracy')