In [None]:
import argparse

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.neighbors import KernelDensity
from tqdm.notebook import tqdm

import data
import pytorch_lightning as pl
from ddlk import ddlk, hrt, mdn, swap, utils
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

%matplotlib inline

Set random seed for reproducibility

In [None]:
# One seed shared by every RNG that pytorch-lightning's helper touches.
SEED = 42
pl.trainer.seed_everything(SEED)

In [None]:
# Count the visible CUDA devices. When at least one exists, request device 0
# only; otherwise leave `gpus` as None. The value is handed to the
# `pl.Trainer(...)` calls further down.
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    gpus = [0]
else:
    gpus = None

# Data

Create PyTorch data loaders

In [None]:
# Dataset settings, collected in one mapping and then exposed as attributes
# because `data.get_data` is called with a Namespace-style object.
data_config = {
    'dataset': 'gaussian_autoregressive_mixture',
    'n_rel': 2,
    'signal_a': 100,
    'd': 10,
    'k': 3,
    'n': 2000,
    'rep': 0,
    'batch_size': 64,
}
data_args = argparse.Namespace(**data_config)

In [None]:
# Build the three loaders described by `data_args`, then name them in the
# order they are returned: train, validation, test.
loaders = data.get_data(data_args)
trainloader, valloader, testloader = loaders

# Fit DDLK

In [None]:
# `get_two_moments` returns two one-element groups here; unwrap each one.
# NOTE(review): presumably the per-feature mean and standard deviation of the
# training features - confirm against `utils.get_two_moments`.
first_moments, second_moments = utils.get_two_moments(trainloader)
(X_mu,) = first_moments
(X_sigma,) = second_moments

Fit `q_joint`

In [None]:
# Hyperparameters for the joint model: only the two data moments are set.
hparams = argparse.Namespace()
hparams.X_mu = X_mu
hparams.X_sigma = X_sigma

In [None]:
# Joint density model, built from the moment hyperparameters.
q_joint = mdn.MDNJoint(hparams)

# Trainer settings for the joint model, gathered in one place.
joint_trainer_kwargs = dict(
    max_epochs=50,
    num_sanity_val_steps=1,
    weights_summary=None,
    deterministic=True,
    gpus=gpus,
)
trainer = pl.Trainer(**joint_trainer_kwargs)

# Train on the training loader and validate on the validation loader.
trainer.fit(q_joint, train_dataloader=trainloader, val_dataloaders=[valloader])

Fit `q_knockoff`

In [None]:
# The knockoff model takes the same two moment hyperparameters as the joint
# model and also receives the fitted `q_joint`.
knockoff_settings = {'X_mu': X_mu, 'X_sigma': X_sigma}
hparams = argparse.Namespace(**knockoff_settings)

q_knockoff = ddlk.DDLK(hparams, q_joint=q_joint)

In [None]:
# Trainer for the knockoff model: up to 100 epochs, with gradient clipping
# at 0.5.
knockoff_trainer_kwargs = {
    'max_epochs': 100,
    'num_sanity_val_steps': 1,
    'deterministic': True,
    'gradient_clip_val': 0.5,
    'weights_summary': None,
    'gpus': gpus,
}
trainer = pl.Trainer(**knockoff_trainer_kwargs)

In [None]:
# Train the knockoff generator on the training loader and validate on the
# validation loader.
trainer.fit(
    q_knockoff,
    train_dataloader=trainloader,
    val_dataloaders=[valloader],
)

# Sample knockoffs

In [None]:
# At this point `extract_data` yields a single tensor for the training loader;
# unwrap it and convert it to a NumPy array for the density plots below.
(train_features,) = utils.extract_data(trainloader)
xTr = train_features.numpy()

In [None]:
# Draw one knockoff sample per training row. No gradients are needed for
# sampling, so autograd is switched off for the whole block.
with torch.no_grad():
    train_tensor = torch.tensor(xTr)
    knockoff_draw = q_knockoff.sample(train_tensor)
    xTr_tilde = knockoff_draw.cpu().numpy()

In [None]:
# Select 2 distinct coordinates at random to visualise.
# The number of candidate coordinates comes from the data itself rather than
# a hard-coded 10, so the cell stays correct when `data_args.d` is changed.
num_features = xTr.shape[1]
assert num_features >= 2, "need at least two features to plot a 2-D density"
j1, j2 = np.random.permutation(num_features)[:2]

In [None]:
# Both estimators share one bandwidth and look at the same pair of columns, so
# the two fitted densities can be compared directly.
KDE_BANDWIDTH = 6
coordinate_pair = [j1, j2]

kde_data = KernelDensity(bandwidth=KDE_BANDWIDTH)
kde_ddlk = KernelDensity(bandwidth=KDE_BANDWIDTH)

# Fit one estimator on the observed data and one on the sampled knockoffs.
kde_data.fit(xTr[:, coordinate_pair])
kde_ddlk.fit(xTr_tilde[:, coordinate_pair])

In [None]:
# Plotting window for the two selected coordinates.
xmin, xmax = -15, 55
ymin, ymax = -15, 55

# Build a 100 x 100 evaluation grid and flatten it to one (x, y) row per grid
# point, which is the layout `score_samples` expects.
grid_x, grid_y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
grid_points = np.column_stack([grid_x.ravel(), grid_y.ravel()])

fig, axarr = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True)

# One panel per fitted estimator: (estimator, fill colormap, panel title).
panels = [
    (kde_data, 'Blues', 'data'),
    (kde_ddlk, 'Reds', 'ddlk'),
]
for ax, (kde, cmap, title) in zip(axarr, panels):
    # `score_samples` returns log-densities; exponentiate to get densities
    # and reshape them back onto the grid.
    density = np.exp(kde.score_samples(grid_points)).reshape(grid_x.shape)
    ax.contourf(grid_x, grid_y, density, cmap=cmap)
    contour_lines = ax.contour(grid_x, grid_y, density, colors='k')
    ax.clabel(contour_lines, inline=1, fontsize=10)
    ax.set_title(title)
    # The coordinates are chosen at random, so say which ones are shown.
    ax.set_xlabel(f'feature {j1}')

# The y axis is shared, so it is labelled once on the left panel.
axarr[0].set_ylabel(f'feature {j2}')
plt.show()

# Perform variable selection

In [None]:
# Switch every loader's dataset to 'prediction' mode. After this,
# `extract_data` yields two tensors (features and responses) per loader.
for loader in (trainloader, valloader, testloader):
    loader.dataset.set_mode('prediction')

# Pull features and responses out of the training and validation loaders.
xTr, yTr = utils.extract_data(trainloader)
xVal, yVal = utils.extract_data(valloader)

# Pool training and validation rows for the HRT fit; features are cast to
# float32 before they are given to the knockoff sampler.
xTr = torch.cat([xTr, xVal], dim=0).float()
yTr = torch.cat([yTr, yVal], dim=0)
xVal = xVal.float()

# Held-out test split, features cast the same way.
xTe, yTe = utils.extract_data(testloader)
xTe = xTe.float()

# Sample knockoffs for the pooled training rows and for the test rows; no
# gradients are needed while sampling.
with torch.no_grad():
    xTr_tilde, xTe_tilde = [
        q_knockoff.sample(features).cpu().numpy() for features in (xTr, xTe)
    ]

knockoff_test = hrt.HRT_Knockoffs(mixture_prop=0.5)

# Fit the test on real features, responses and their knockoffs; the notebook
# `tqdm` is handed in for progress reporting.
knockoff_test.fit(xTr, yTr, xTr_tilde, tqdm=tqdm)

In [None]:
# Score the fitted test on the held-out split: test features, test responses
# and the knockoffs sampled for the test features.
knockoff_statistics = knockoff_test.score(
    xTe,
    yTe,
    xTe_tilde,
    tqdm=tqdm,
)

In [None]:
# Flatten the dataset's `beta` once and reuse it for both values and index.
beta_flat = trainloader.dataset.beta.flatten()

# One column with the knockoff statistic per feature ...
knockoff_statistics = pd.Series(knockoff_statistics)
statistic_frame = pd.DataFrame(knockoff_statistics, columns=['statistic'])

# ... and one column with the dataset's `beta`, indexed 0..len(beta)-1.
beta_frame = pd.DataFrame(beta_flat,
                          index=np.arange(beta_flat.shape[0]),
                          columns=['beta'])

# Left-join on the index so each statistic sits next to its `beta` entry.
results = statistic_frame.join(beta_frame)
results.index.name = 'feature'

In [None]:
# Display the per-feature table: knockoff statistic beside the dataset's `beta`.
results