In [7]:
import math
import os
import sys

# Extend the import path with the parent directory before any non-stdlib import
# (presumably so the in-repo `delphi` package is importable -- confirm).
sys.path.append('..')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch as ch
from cox.store import Store
from cox.utils import Parameters
from sklearn.metrics import mean_squared_error
from torch import Tensor
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.utils.data import DataLoader

from delphi.oracle import Left
from delphi.stats.truncated_regression import truncated_regression
from delphi.utils.datasets import TruncatedRegressionDataset

# Create Store
Create a store using [cox](https://github.com/MadryLab/cox), MadryLab's lightweight experiment design and analysis framework.
The store holds a single table; each row of that table records the results of one run of the experiment.

In [11]:
STORE_PATH = '<Give Store Path Here>'
STORE_TABLE_NAME = '<Give Store Table Name Here>'
STORE_PATH = '/home/pstefanou/test_'
STORE_TABLE_NAME = 'test'

store = Store(STORE_PATH)

store.add_table(STORE_TABLE_NAME, { 
    'delphi_param_mse': float,
    'delphi_var_mse': float, 
    'var': float, 
    'ols_param_mse': float, 
    'ols_var_mse': float,
    'alpha': float, 
    'c': float,
})

Logging in: /home/pstefanou/test_/dd2bfcd9-87cd-4f7d-86b8-db928fce09b0


<cox.store.Table at 0x7f66d82ecd68>

# Experiment
Run the experiment, varying the truncation parameter $C$ over the integers in $[-5, 1]$. The whole sweep is repeated 10 times; for every repetition and every value of $C$ we generate fresh data and re-run the same estimation procedure. All results are written to the store initialized in the previous cell.

In [None]:
# regression parameters
num_samples, dims = 10000, 10
noise_var = Tensor([5.0])
W = ch.ones(dims, 1)
W0 = ch.ones(1, 1)

# perform each experiment a total of 10 times
for iter_ in range(10):        
    for c in range(-5, 2):
        # generate data
        X = MultivariateNormal(ch.zeros(dims), ch.eye(dims)/dims).sample(ch.Size([num_samples]))
        y = X.mm(W) + W0 + Normal(ch.zeros(1), ch.sqrt(noise_var)).sample(ch.Size([num_samples]))
        # truncate
        phi = Left(Tensor([c]))
        indices = phi(y).nonzero(as_tuple=False).flatten()
        y_trunc, x_trunc = y[indices], X[indices]

        # experiment parameters
        args = Parameters({ 
            'alpha': Tensor([y_trunc.size(0)/num_samples]), 
            'phi': phi, 
            'epochs': 50,
            'num_workers': 2, 
            'batch_size': 100,
            'bias': True,
            'num_samples': 100,
            'clamp': True, 
            'radius': 2.0, 
            'var_lr': 1e-1,
            'lr': 1e-1,
        })


        # dataset 
        data = TruncatedRegressionDataset(x_trunc, y_trunc, bias=args.bias, unknown=True)
        S = DataLoader(data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
        
        trunc_reg = truncated_regression(phi=phi, alpha=args.alpha, epochs=args.epochs, var_lr=args.var_lr, lr=args.lr)

        results = trunc_reg.fit(S)
        var_ = results.lambda_.inverse().detach().cpu()
        w = results.v.detach().cpu()*var_
        w0 = results.bias.detach().cpu()*var_
        # calculate metrics 
        real_params = ch.cat([W, W0])
        ols_params = ch.cat([data.w, data.w0.unsqueeze(0)])
        delphi_params = ch.cat([w, w0])
        delphi_param_mse = mean_squared_error(delphi_params, real_params)
        delphi_var_mse = mean_squared_error(var_, noise_var)
        ols_param_mse = mean_squared_error(ols_params, real_params)
        ols_var_mse = mean_squared_error(noise_var, data.lambda_.inverse())

        store[STORE_TABLE_NAME].append_row({ 
            'delphi_param_mse': delphi_param_mse,
            'delphi_var_mse': delphi_var_mse, 
            'var': float(noise_var), 
            'ols_param_mse': ols_param_mse,
            'ols_var_mse': ols_var_mse,
            'alpha': float(args.alpha.flatten()),
            'c': c
        })

Epoch:1 | Loss 3.9093 | Train1 0.166 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 58.38it/s]
Epoch:2 | Loss 11.0840 | Train1 0.163 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 61.23it/s]
Epoch:3 | Loss 5.7313 | Train1 0.164 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 58.98it/s]
Epoch:4 | Loss 4.7278 | Train1 0.162 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 59.42it/s]
Epoch:5 | Loss 3.6274 | Train1 0.165 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 60.78it/s]
Epoch:6 | Loss 2.8357 | Train1 0.168 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 57.73it/s]
Epoch:7 | Loss 3.8323 | Train1 0.166 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 57.04it/s]
Epoch:8 | Loss 5.5009 | Train1 0.169 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 100/100 [00:01<00:00, 56.27it/s]
Epoch:9 | Loss 

Epoch:17 | Loss 3.1538 | Train1 0.168 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 58.78it/s]
Epoch:18 | Loss 8.1114 | Train1 0.165 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 56.30it/s] 
Epoch:19 | Loss 15.2329 | Train1 0.168 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 55.33it/s]
Epoch:20 | Loss 2.8652 | Train1 0.170 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 57.93it/s]
Epoch:21 | Loss 10.3504 | Train1 0.166 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 56.47it/s]
Epoch:22 | Loss 2.7409 | Train1 0.171 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 56.04it/s]
Epoch:23 | Loss 2.6113 | Train1 0.169 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 56.87it/s]
Epoch:24 | Loss 3.9410 | Train1 0.168 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 99/99 [00:01<00:00, 55.79it/s]
Epoch:25 | Loss 3.027

Epoch:34 | Loss 2.4999 | Train1 0.178 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 66.30it/s]
Epoch:35 | Loss 2.4622 | Train1 0.181 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 66.50it/s]
Epoch:36 | Loss 2.7504 | Train1 0.177 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 66.25it/s]
Epoch:37 | Loss 2.4758 | Train1 0.180 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 66.57it/s]
Epoch:38 | Loss 2.2938 | Train1 0.176 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 55.55it/s]
Epoch:39 | Loss 2.3628 | Train1 0.178 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 56.21it/s]
Epoch:40 | Loss 2.5387 | Train1 0.180 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 55.43it/s]
Epoch:41 | Loss 2.3587 | Train1 0.178 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 95/95 [00:01<00:00, 55.90it/s]
Epoch:42 | Loss 2.3407 |

Epoch:1 | Loss 17.5832 | Train1 0.200 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 53.09it/s]
Epoch:2 | Loss 3.9543 | Train1 0.201 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 53.38it/s]
Epoch:3 | Loss 6.1216 | Train1 0.204 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 54.16it/s] 
Epoch:4 | Loss 16.6690 | Train1 0.202 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 53.78it/s]
Epoch:5 | Loss 2.9257 | Train1 0.205 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 52.83it/s]
Epoch:6 | Loss 11.5443 | Train1 0.197 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 53.14it/s]
Epoch:7 | Loss 12.5618 | Train1 0.200 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 52.64it/s]
Epoch:8 | Loss 17.5753 | Train1 0.202 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 80/80 [00:01<00:00, 53.78it/s]
Epoch:9 | Loss 2.6460 | Tr

Epoch:18 | Loss 2.6740 | Train1 0.246 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 65.40it/s]
Epoch:19 | Loss 2.1346 | Train1 0.252 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 50.07it/s]
Epoch:20 | Loss 2.7681 | Train1 0.242 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 52.33it/s]
Epoch:21 | Loss 14.3121 | Train1 0.227 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 53.96it/s]
Epoch:22 | Loss 3.5248 | Train1 0.239 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 55.48it/s]
Epoch:23 | Loss 2.1196 | Train1 0.245 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 51.35it/s]
Epoch:24 | Loss 2.4722 | Train1 0.254 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 50.46it/s]
Epoch:25 | Loss 3.7211 | Train1 0.241 | Train5 -1.000 | Reg term: 0.0 ||: 100%|██████████| 67/67 [00:01<00:00, 50.55it/s]
Epoch:26 | Loss 2.0723 

In [None]:
results = store[STORE_TABLE_NAME].df
results.head()

# Plot Results

In [None]:
# plot results for regression parameter MSE
sns.lineplot(data=results, x='c', y='delphi_param_mse', label='delphi', color='blue')
ax = sns.lineplot(data=results, x='c', y='ols_param_mse', label='ols', color='red')
ax.set(xlabel='Truncation Parameter C', ylabel='MSE')
plt.show()

In [None]:
# plot results for regression noise variance MSE
sns.lineplot(data=results, x='c', y='delphi_var_mse', label='delphi', color="blue")
ax = sns.lineplot(data=results, x='c', y='ols_var_mse', label='ols', color="red")
ax.set(xlabel='Truncation Parameter C', ylabel='MSE')
plt.show()