In [1]:
import sys 
sys.path.append('..')
from cox.utils import Parameters
from cox.store import Store
from sklearn.metrics import mean_squared_error
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import numpy as np
import torch as ch
from torch import Tensor
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader
from delphi.stats.truncated_regression import truncated_regression
from delphi.utils.datasets import TruncatedRegressionDataset
from delphi.oracle import Left

# Create Store
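Create a store using [cox](https://github.com/MadryLab/cox), MadryLab's lightweight framework for experimental design and analysis. 
The store holds a single table with the results of the experiment: for every run it records the parameter MSE and noise-variance MSE of both delphi and OLS, together with the true noise variance (`var`), the fraction of samples that survive truncation (`alpha`), and the truncation parameter (`c`).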

In [2]:
# Location of the cox store and the name of the results table.
# Replace both values with your own before running, e.g.
#   STORE_PATH = '<Give Store Path Here>'
#   STORE_TABLE_NAME = '<Give Store Table Name Here>'
# NOTE(review): the path below is an absolute, machine-specific directory; cox
# creates a fresh uniquely-named sub-directory under it on every run.
STORE_PATH = '/home/pstefanou/test_'
STORE_TABLE_NAME = 'test'

store = Store(STORE_PATH)

# One row is appended per (repetition, truncation value) pair by the experiment cell.
store.add_table(STORE_TABLE_NAME, { 
    'delphi_param_mse': float,  # MSE of delphi's weights + bias vs. the true ones
    'delphi_var_mse': float,    # MSE of delphi's noise-variance estimate
    'var': float,               # true noise variance used to generate the data
    'ols_param_mse': float,     # MSE of the OLS weights + bias vs. the true ones
    'ols_var_mse': float,       # MSE of the OLS noise-variance estimate
    'alpha': float,             # fraction of samples that survived truncation
    'c': float,                 # truncation parameter
})

Logging in: /home/pstefanou/test_/2ebd4d1e-b441-4bdd-b13a-f11068b458a9


<cox.store.Table at 0x7f72b6e199e8>

# Experiment
Run the experiment, varying the truncation parameter C over the integers in \[-5, 1\]. The whole sweep is repeated 10 times; for every value of C in every repetition we generate new data and re-run the same procedure. All results are appended to the store that was initialized in the previous cell.

In [None]:
# regression parameters
num_samples, dims = 10000, 10
noise_var = Tensor([5.0])
W = ch.ones(dims, 1)
W0 = ch.ones(1, 1)

# sweep configuration
NUM_TRIALS = 10                   # number of times the whole sweep is repeated
TRUNCATION_POINTS = range(-5, 2)  # integer values of c: -5, -4, ..., 1

# Loop-invariant objects: built once instead of on every iteration.
feature_dist = MultivariateNormal(ch.zeros(dims), ch.eye(dims)/dims)
noise_dist = Normal(ch.zeros(1), ch.sqrt(noise_var))  # Normal takes a std, hence the sqrt
real_params = ch.cat([W, W0])

# perform each experiment a total of NUM_TRIALS times
for iter_ in range(NUM_TRIALS):
    for c in TRUNCATION_POINTS:
        # generate data: y = X W + W0 + noise
        X = feature_dist.sample(ch.Size([num_samples]))
        y = X.mm(W) + W0 + noise_dist.sample(ch.Size([num_samples]))

        # truncate: keep only the rows for which the oracle returns a non-zero value
        # (presumably Left(c) keeps samples with y > c -- confirm against delphi.oracle)
        phi = Left(Tensor([c]))
        indices = phi(y).nonzero(as_tuple=False).flatten()
        y_trunc, x_trunc = y[indices], X[indices]

        # experiment parameters
        # NOTE(review): 'num_samples', 'clamp' and 'radius' are never read in this
        # cell and are not forwarded to truncated_regression -- confirm whether they
        # are still needed.
        args = Parameters({
            'alpha': Tensor([y_trunc.size(0)/num_samples]),  # empirical survival fraction
            'phi': phi,
            'epochs': 50,
            'num_workers': 2,
            'batch_size': 100,
            'bias': True,
            'num_samples': 100,
            'clamp': True,
            'radius': 2.0,
            'var_lr': 1e-1,
            'lr': 1e-1,
        })

        # dataset
        data = TruncatedRegressionDataset(x_trunc, y_trunc, bias=args.bias, unknown=True)
        S = DataLoader(data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)

        trunc_reg = truncated_regression(phi=phi, alpha=args.alpha, epochs=args.epochs, var_lr=args.var_lr, lr=args.lr)

        # `fit_result` (not `results`) so the name does not collide with the
        # DataFrame that a later cell reads back from the store.
        fit_result = trunc_reg.fit(S)

        # Variance estimate is the inverse of lambda_; weights and bias are the
        # fitted v / bias scaled by that variance.
        var_ = fit_result.lambda_.inverse().detach().cpu()
        w = fit_result.v.detach().cpu()*var_
        w0 = fit_result.bias.detach().cpu()*var_

        # calculate metrics
        ols_params = ch.cat([data.w, data.w0.unsqueeze(0)])
        delphi_params = ch.cat([w, w0])
        delphi_param_mse = mean_squared_error(delphi_params, real_params)
        delphi_var_mse = mean_squared_error(var_, noise_var)
        ols_param_mse = mean_squared_error(ols_params, real_params)
        ols_var_mse = mean_squared_error(noise_var, data.lambda_.inverse())

        # Every column is declared as `float` in the table schema, so cast
        # explicitly (c is an int, the metrics are whatever sklearn returns).
        store[STORE_TABLE_NAME].append_row({
            'delphi_param_mse': float(delphi_param_mse),
            'delphi_var_mse': float(delphi_var_mse),
            'var': float(noise_var),
            'ols_param_mse': float(ols_param_mse),
            'ols_var_mse': float(ols_var_mse),
            'alpha': float(args.alpha.flatten()),
            'c': float(c),
        })

grad.sizes() = [10, 1], strides() = [1, 10]
param.sizes() = [10, 1], strides() = [1, 1] (Triggered internally at  /pytorch/torch/csrc/autograd/functions/accumulate_grad.h:170.)
  allow_unreachable=True)  # allow_unreachable flag
Epoch:1 | Loss 3.7838 | Train1 0.163 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 68.17it/s]
Epoch:2 | Loss 4.3130 | Train1 0.160 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 67.95it/s]
Epoch:3 | Loss 4.9975 | Train1 0.162 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 56.99it/s]
Epoch:4 | Loss 3.7249 | Train1 0.158 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 57.50it/s]
Epoch:5 | Loss 7.4767 | Train1 0.162 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 57.20it/s]
Epoch:6 | Loss 13.9459 | Train1 0.161 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 57.10it/s]
Epoch:7 | Loss 4.3487 | Train1 0

Epoch:12 | Loss 13.4977 | Train1 0.172 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 55.59it/s]
Epoch:13 | Loss 3.7185 | Train1 0.175 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 54.59it/s]
Epoch:14 | Loss 3.7731 | Train1 0.174 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 53.64it/s]
Epoch:15 | Loss 4.1842 | Train1 0.172 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 55.62it/s]
Epoch:16 | Loss 5.4362 | Train1 0.173 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 56.05it/s]
Epoch:17 | Loss 2.8539 | Train1 0.178 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 56.91it/s]
Epoch:18 | Loss 7.8462 | Train1 0.171 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 57.84it/s] 
Epoch:19 | Loss 3.7259 | Train1 0.176 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:01<00:00, 57.80it/s]
Epoch:20 | Loss 7.5013

Epoch:29 | Loss 3.9663 | Train1 0.172 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 67.09it/s]
Epoch:30 | Loss 14.4574 | Train1 0.172 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 68.57it/s]
Epoch:31 | Loss 4.3231 | Train1 0.170 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 67.44it/s]
Epoch:32 | Loss 3.3807 | Train1 0.173 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 65.37it/s]
Epoch:33 | Loss 4.2170 | Train1 0.172 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 66.06it/s]
Epoch:34 | Loss 15.5301 | Train1 0.171 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 66.47it/s]
Epoch:35 | Loss 5.5427 | Train1 0.175 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 66.04it/s]
Epoch:36 | Loss 7.8050 | Train1 0.164 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 65.60it/s] 
Epoch:37 | Loss 7.104

Epoch:46 | Loss 3.0288 | Train1 0.184 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 89/89 [00:01<00:00, 66.21it/s]
Epoch:47 | Loss 3.0512 | Train1 0.184 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 89/89 [00:01<00:00, 68.02it/s]
Epoch:48 | Loss 7.7299 | Train1 0.182 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 89/89 [00:01<00:00, 54.46it/s]
Epoch:49 | Loss 2.3024 | Train1 0.185 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 89/89 [00:01<00:00, 57.16it/s]
Epoch:50 | Loss 2.6976 | Train1 0.186 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 89/89 [00:01<00:00, 55.03it/s]
Epoch:1 | Loss 3.1876 | Train1 0.207 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 54.96it/s]
Epoch:2 | Loss 2.6264 | Train1 0.212 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 54.87it/s]
Epoch:3 | Loss 3.0067 | Train1 0.213 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 55.33it/s]
Epoch:4 | Loss 8.4510 | Tra

In [None]:
# Read every logged row back from the store as a DataFrame and preview it.
experiment_table = store[STORE_TABLE_NAME]
results = experiment_table.df
results.head()

# Plot Results

In [None]:
# plot results for regression parameter MSE
# Draw both curves on one explicitly created Axes instead of relying on
# pyplot's implicit "current axes", and title the figure so it stands alone.
fig, ax = plt.subplots()
sns.lineplot(data=results, x='c', y='delphi_param_mse', label='delphi', color='blue', ax=ax)
sns.lineplot(data=results, x='c', y='ols_param_mse', label='ols', color='red', ax=ax)
ax.set(title='Regression parameter MSE: delphi vs. OLS',
       xlabel='Truncation Parameter C', ylabel='MSE')
plt.show()

In [None]:
# plot results for regression noise variance MSE
# Draw both curves on one explicitly created Axes instead of relying on
# pyplot's implicit "current axes", and title the figure so it stands alone.
fig, ax = plt.subplots()
sns.lineplot(data=results, x='c', y='delphi_var_mse', label='delphi', color="blue", ax=ax)
sns.lineplot(data=results, x='c', y='ols_var_mse', label='ols', color="red", ax=ax)
ax.set(title='Noise variance MSE: delphi vs. OLS',
       xlabel='Truncation Parameter C', ylabel='MSE')
plt.show()