In [None]:
%load_ext autoreload
%autoreload 2
# %matplotlib notebook

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
import pickle
import json
import torch
from tqdm import tqdm, tqdm_notebook

from bullseye import BullseyeData

# matplotlib and seaborn settings: all text is rendered through LaTeX, so a TeX
# installation must be available on the machine running the notebook.
from matplotlib import rc
rc('text', usetex=True)
# A single string is the accepted type for the preamble; the list form is
# deprecated since Matplotlib 3.3 and rejected by newer releases.
plt.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

sns.set_style("white")

# Shared font settings for axis labels (passed as `fontdict=` below).
font = {'family': 'serif',
        'color':  'black',
        'weight': 'normal',
        'size': 14}

## Plot 2D Bullseye data samples

In [None]:
# Scatter plot of a 500-point Bullseye sample, saved to assets/bullseye.pdf.
import os

datagen = BullseyeData(500, 0.075, copies=1, scale_4=True)

fig, ax = plt.subplots(figsize=(4, 4))
sns.despine(ax=ax, left=True, bottom=True)
ax.scatter(datagen.X[:, 0], datagen.X[:, 1])
# Force a square plot box whatever the x/y data ranges are.
(x0, x1), (y0, y1) = ax.get_xlim(), ax.get_ylim()
ax.set_aspect(abs(x1 - x0) / abs(y1 - y0))
# Blank labels kept from the original (presumably to reserve label space so this
# figure lines up with the labelled ones).
ax.set_xlabel(r" ", fontdict=font)
ax.set_ylabel(r" ", fontdict=font)
fig.tight_layout()
os.makedirs("assets", exist_ok=True)  # savefig does not create missing directories
fig.savefig("assets/bullseye.pdf")
plt.show()

### Learned mappings for $I(X;Y)$: vanilla vs. proposed method at fixed $n$

In [None]:
# Fixed-size Bullseye dataset written to HDF5; the mapping-plot cell later reads
# its 'X' and 'Y' datasets back from this same path.
n_fixed = 1000
datagen = BullseyeData(n_fixed, 0.075, scale_4=True, copies=1)
datagen.make_X_data('data/bullseye_fixedn.h5')

In [None]:
# Baseline ("vanilla") model: build its manager from the JSON config and train it.
# NOTE(review): ModelManager is not imported anywhere in the visible notebook;
# add its import to the first cell or this cell raises NameError on a fresh kernel.
config_path = 'config/config_bullseye_vanilla.json'
vanilla_manager = ModelManager(config_path)
vanilla_manager.train()

In [None]:
# Proposed method: build its manager from the JSON config and train it.
# NOTE(review): ModelManager is not imported anywhere in the visible notebook;
# add its import to the first cell or this cell raises NameError on a fresh kernel.
config_path = 'config/config_bullseye_proposed.json'
proposed_manager = ModelManager(config_path)
proposed_manager.train()

In [None]:
import scipy.interpolate

# Inset (in latent units) of the interpolation grid from the plot border.
eps = 5e-2
# Checkpoint file handed to process_numpy for both the vanilla and proposed models.
save_file = "checkpoint-epoch150.pth"

def forceAspect(ax, aspect=1):
    """Make the box of `ax` have width/height ratio `aspect`, based on its first image.

    The data extent ``(left, right, bottom, top)`` of the first image on `ax` gives
    the x/y data-range ratio. Passing that ratio divided by `aspect` to
    ``ax.set_aspect`` makes the drawn box `aspect` times as wide as it is tall.
    ``aspect=1`` gives a square box.

    Raises IndexError if no image has been drawn on `ax` yet.
    """
    left, right, bottom, top = ax.get_images()[0].get_extent()
    data_ratio = abs((right - left) / (top - bottom))
    ax.set_aspect(data_ratio / aspect)

# Latent-space mappings. For each trained model:
#   1. encode the fixed-n data;
#   2. evaluate the model's predictor on random probe points covering the
#      surrounding latent plane;
#   3. interpolate those predictions onto a regular grid and draw it behind the
#      encoded points.
# NOTE(review): ModelManager.process_numpy and model._predict are defined outside
# this notebook; the shapes noted below are assumptions to confirm.

# NOTE(review): the original comment here read "estimate for R"; only X is used below.
with h5py.File("data/bullseye_fixedn.h5", "r") as f:
    x_data = np.array(f['X'])
    y_data = np.array(f['Y'])

LATENT_MARGIN = 0.6   # padding (latent units) added around the encoded points
N_PROBE = 5000        # random latent points evaluated by each model
GRID_RES = 500        # grid side length; was 300 for vanilla and 500 for proposed
PROBE_SEED = 0        # fixes the probe points so re-running gives identical figures


def _latent_field(manager, x, checkpoint, margin=LATENT_MARGIN, n_probe=N_PROBE,
                  grid_res=GRID_RES, inset=eps, seed=PROBE_SEED):
    """Encode `x` with `manager` and build an interpolated prediction map around it.

    Returns ``(z, field, extent)``:
      * `z` is the second output of ``manager.process_numpy(x, checkpoint)``.
        It is assumed to be an (n, 2) array of latent codes (TODO confirm);
        only columns 0 and 1 are used.
      * `field` is a ``(grid_res, grid_res)`` array of linearly interpolated
        predictions. It is NaN where a grid point lies outside the probes'
        convex hull.
      * `extent` is ``[xmin, xmax, ymin, ymax]`` for ``imshow``.
    """
    _, z = manager.process_numpy(x, checkpoint)
    # Python floats keep the torch arithmetic below in the probes' float32 dtype.
    xmin, xmax = float(np.min(z[:, 0])) - margin, float(np.max(z[:, 0])) + margin
    ymin, ymax = float(np.min(z[:, 1])) - margin, float(np.max(z[:, 1])) + margin

    gen = torch.Generator().manual_seed(seed)
    u = torch.rand(n_probe, 2, generator=gen)
    probes = torch.stack((xmax + (xmin - xmax) * u[:, 0],
                          ymax + (ymin - ymax) * u[:, 1]), dim=1)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        pred = manager.model._predict(probes.to(device)).cpu().numpy().flatten()
    probes = probes.numpy()

    gx, gy = np.meshgrid(np.linspace(xmin + inset, xmax - inset, grid_res),
                         np.linspace(ymin + inset, ymax - inset, grid_res))
    field = scipy.interpolate.griddata((probes[:, 0], probes[:, 1]), pred, (gx, gy),
                                       method='linear')
    return z, field, [xmin, xmax, ymin, ymax]


def _draw_latent_panel(ax, z, field, extent):
    """Draw one prediction map with its colorbar and the encoded points on `ax`."""
    im = ax.imshow(field, origin='lower', extent=extent, cmap='inferno', alpha=0.7)
    ax.figure.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
    ax.scatter(z[:, 0], z[:, 1], c='k', s=20, marker='.')
    forceAspect(ax, aspect=1)


z_vanilla, zi_vanilla, extent_vanilla = _latent_field(vanilla_manager, x_data, save_file)
z_prop, zi_proposed, extent_proposed = _latent_field(proposed_manager, x_data, save_file)
panels = {
    "vanilla": (z_vanilla, zi_vanilla, extent_vanilla),
    "proposed": (z_prop, zi_proposed, extent_proposed),
}

# Side-by-side comparison shown in the notebook.
fig, axes = plt.subplots(1, 2, figsize=(8, 6))
for ax, (name, panel) in zip(axes, panels.items()):
    _draw_latent_panel(ax, *panel)
    ax.set_title(name, fontdict=font)
fig.tight_layout()
plt.show()

# One PDF per model, so each file holds its own panel. The original code saved
# the same two-panel figure under both names.
for name, panel in panels.items():
    fig_single, ax_single = plt.subplots(figsize=(4, 4))
    _draw_latent_panel(ax_single, *panel)
    fig_single.tight_layout()
    fig_single.savefig("bullseye_{}_z.pdf".format(name))
    plt.close(fig_single)